import torch
from einops import rearrange, repeat


class TileWorker:
    def __init__(self):
        pass

    def mask(self, height, width, border_width):
        # Create a mask with shape (height, width).
        # The centre area is filled with 1, and the border is filled with values in (0, 1].
        x = torch.arange(height).repeat(width, 1).T
        y = torch.arange(width).repeat(height, 1)
        mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
        mask = (mask / border_width).clip(0, 1)
        return mask

    def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
        # Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num).
        batch_size, channel, _, _ = model_input.shape
        model_input = model_input.to(device=tile_device, dtype=tile_dtype)
        unfold_operator = torch.nn.Unfold(
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        model_input = unfold_operator(model_input)
        model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
        return model_input

    def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
        # Call y = forward_fn(x) for each batch of tiles.
        tile_num = model_input.shape[-1]
        model_output_stack = []

        for tile_id in range(0, tile_num, tile_batch_size):
            # process input
            tile_id_ = min(tile_id + tile_batch_size, tile_num)
            x = model_input[:, :, :, :, tile_id: tile_id_]
            x = x.to(device=inference_device, dtype=inference_dtype)
            x = rearrange(x, "b c h w n -> (n b) c h w")
            # process output
            y = forward_fn(x)
            y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_ - tile_id)
            y = y.to(device=tile_device, dtype=tile_dtype)
            model_output_stack.append(y)

        model_output = torch.concat(model_output_stack, dim=-1)
        return model_output

    def io_scale(self, model_output, tile_size):
        # Determine the spatial rescaling applied by forward_fn.
        # Only a uniform scale on height and width is considered.
        io_scale = model_output.shape[2] / tile_size
        return io_scale

    def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
        # The inverse of `tile`: blend the overlapping tiles back into one tensor.
        # Each tile is weighted by the border mask, then the folded weighted sum is
        # normalised by the folded mask, so overlapping regions average smoothly.
        mask = self.mask(tile_size, tile_size, border_width)
        mask = mask.to(device=tile_device, dtype=tile_dtype)
        mask = rearrange(mask, "h w -> 1 1 h w 1")
        model_output = model_output * mask

        fold_operator = torch.nn.Fold(
            output_size=(height, width),
            kernel_size=(tile_size, tile_size),
            stride=(tile_stride, tile_stride)
        )
        mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
        model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
        model_output = fold_operator(model_output) / fold_operator(mask)

        return model_output

    def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
        # Prepare
        inference_device, inference_dtype = model_input.device, model_input.dtype
        height, width = model_input.shape[2], model_input.shape[3]
        border_width = int(tile_stride * 0.5) if border_width is None else border_width

        # tile
        model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)

        # inference
        model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)

        # resize
        io_scale = self.io_scale(model_output, tile_size)
        height, width = int(height * io_scale), int(width * io_scale)
        tile_size, tile_stride = int(tile_size * io_scale), int(tile_stride * io_scale)
        border_width = int(border_width * io_scale)

        # untile
        model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)

        # Done!
        model_output = model_output.to(device=inference_device, dtype=inference_dtype)
        return model_output
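
# A minimal usage sketch of TileWorker (illustrative only: `_demo_tile_worker`,
# the tensor sizes, and the identity forward_fn are assumptions for the demo,
# not part of the original code). With an identity forward_fn, blending the
# overlapping tiles back together should reproduce the input almost exactly.
def _demo_tile_worker():
    worker = TileWorker()
    x = torch.randn(1, 16, 96, 96)
    # 96x96 input, 64x64 tiles, stride 32 -> a 2x2 grid of overlapping tiles.
    y = worker.tiled_forward(lambda t: t, x, tile_size=64, tile_stride=32)
    assert y.shape == x.shape
    # The mask-weighted average of identical overlapping tiles is the input itself.
    assert torch.allclose(x, y, atol=1e-5)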
class ConvAttention(torch.nn.Module):
    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # 1x1 convolutions act as per-token linear projections on (B, C, L, 1) inputs.
        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        batch_size = encoder_hidden_states.shape[0]

        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
        q = self.to_q(conv_input)
        q = rearrange(q[:, :, :, 0], "B C L -> B L C")

        conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
        k = self.to_k(conv_input)
        v = self.to_v(conv_input)
        k = rearrange(k[:, :, :, 0], "B C L -> B L C")
        v = rearrange(v[:, :, :, 0], "B C L -> B L C")

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        conv_input = rearrange(hidden_states, "B L C -> B C L 1")
        hidden_states = self.to_out(conv_input)
        hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
        return hidden_states


class VAEAttentionBlock(torch.nn.Module):
    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.transformer_blocks = torch.nn.ModuleList([
            ConvAttention(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                bias_q=True,
                bias_kv=True,
                bias_out=True
            )
            for d in range(num_layers)
        ])

    def forward(self, hidden_states, time_emb, text_emb, res_stack):
        batch, _, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states)

        hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack


class ResnetBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = torch.nn.SiLU()
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        x = hidden_states
        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)
        if time_emb is not None:
            emb = self.nonlinearity(time_emb)
            emb = self.time_emb_proj(emb)[:, :, None, None]
            x = x + emb
        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)
        if self.conv_shortcut is not None:
            hidden_states = self.conv_shortcut(hidden_states)
        hidden_states = hidden_states + x
        return hidden_states, time_emb, text_emb, res_stack


class UpSampler(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack


class DownSampler(torch.nn.Module):
    def __init__(self, channels, padding=1, extra_padding=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
        self.extra_padding = extra_padding

    def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
        if self.extra_padding:
            # Asymmetric right/bottom padding so the stride-2 convolution halves odd sizes cleanly.
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states, time_emb, text_emb, res_stack
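
# A usage sketch of ConvAttention as plain self-attention over a token
# sequence (illustrative only: `_demo_conv_attention` and the dimensions
# below are assumptions, not part of the original code).
def _demo_conv_attention():
    attn = ConvAttention(q_dim=64, num_heads=4, head_dim=16, bias_q=True, bias_kv=True, bias_out=True)
    tokens = torch.randn(2, 10, 64)  # (batch, sequence length, channels)
    out = attn(tokens)               # encoder_hidden_states defaults to the input -> self-attention
    assert out.shape == tokens.shape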
class FluxVAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159
        self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1)  # Different from SD 1.x: 16 latent channels

        self.blocks = torch.nn.ModuleList([
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            UpSampler(512),
            # UpDecoderBlock2D
            ResnetBlock(512, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            UpSampler(256),
            # UpDecoderBlock2D
            ResnetBlock(256, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE decoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = sample / self.scaling_factor + self.shift_factor
        hidden_states = self.conv_in(hidden_states)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
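
# A usage sketch of FluxVAEDecoder with randomly initialised weights
# (illustrative only: no checkpoint loading is shown, so the decoded image
# is noise; the sketch just checks the 8x spatial upsampling contract).
def _demo_flux_vae_decoder():
    decoder = FluxVAEDecoder().eval()
    latents = torch.randn(1, 16, 16, 16)
    with torch.no_grad():
        image = decoder(latents)
    # Three UpSamplers give 8x upsampling: 16x16 latents -> 128x128 RGB.
    assert image.shape == (1, 3, 128, 128)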
class FluxVAEEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
        self.shift_factor = 0.1159
        self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.blocks = torch.nn.ModuleList([
            # DownEncoderBlock2D
            ResnetBlock(128, 128, eps=1e-6),
            ResnetBlock(128, 128, eps=1e-6),
            DownSampler(128, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(128, 256, eps=1e-6),
            ResnetBlock(256, 256, eps=1e-6),
            DownSampler(256, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(256, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            DownSampler(512, padding=0, extra_padding=True),
            # DownEncoderBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
            # UNetMidBlock2D
            ResnetBlock(512, 512, eps=1e-6),
            VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
            ResnetBlock(512, 512, eps=1e-6),
        ])

        self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
        self.conv_act = torch.nn.SiLU()
        self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)

    def tiled_forward(self, sample, tile_size=64, tile_stride=32):
        hidden_states = TileWorker().tiled_forward(
            lambda x: self.forward(x),
            sample,
            tile_size,
            tile_stride,
            tile_device=sample.device,
            tile_dtype=sample.dtype
        )
        return hidden_states

    def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
        # For the VAE encoder, we do not need to apply the tiler on each layer.
        if tiled:
            return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)

        # 1. pre-process
        hidden_states = self.conv_in(sample)
        time_emb = None
        text_emb = None
        res_stack = None

        # 2. blocks
        for i, block in enumerate(self.blocks):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)

        # 3. output
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)
        # Keep only the mean of the latent distribution; the other 16 channels are the log-variance.
        hidden_states = hidden_states[:, :16]
        hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor

        return hidden_states

    def encode_video(self, sample, batch_size=8):
        # Encode a video tensor (B, C, T, H, W) frame by frame, `batch_size` frames at a time.
        B = sample.shape[0]
        hidden_states = []

        for i in range(0, sample.shape[2], batch_size):
            j = min(i + batch_size, sample.shape[2])
            sample_batch = rearrange(sample[:, :, i:j], "B C T H W -> (B T) C H W")
            hidden_states_batch = self(sample_batch)
            hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
            hidden_states.append(hidden_states_batch)

        hidden_states = torch.concat(hidden_states, dim=2)
        return hidden_states
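
# An end-to-end sketch of the encode/decode round trip with random weights
# (illustrative only: it verifies the 8x shape contract, not reconstruction
# quality, since no pretrained checkpoint is loaded).
def _demo_flux_vae_roundtrip():
    encoder, decoder = FluxVAEEncoder().eval(), FluxVAEDecoder().eval()
    image = torch.randn(1, 3, 128, 128)
    with torch.no_grad():
        latents = encoder(image)   # (1, 16, 16, 16): 16 latent channels, 8x downsampled
        recon = decoder(latents)   # (1, 3, 128, 128)
    assert latents.shape == (1, 16, 16, 16)
    assert recon.shape == image.shape


if __name__ == "__main__":
    # Run the illustrative smoke tests defined above.
    _demo_tile_worker()
    _demo_conv_attention()
    _demo_flux_vae_decoder()
    _demo_flux_vae_roundtrip()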