import torch
import torch.nn as nn
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d
from .utils import hash_state_dict_keys


class WanControlNetModel(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        in_dim: int,
        ffn_dim: int,
        out_dim: int,
        text_dim: int,
        freq_dim: int,
        eps: float,
        patch_size: Tuple[int, int, int],
        num_heads: int,
        num_layers: int,
        has_image_input: bool,
    ):
        super().__init__()
        self.dim = dim
        self.freq_dim = freq_dim
        self.has_image_input = has_image_input
        self.patch_size = patch_size

        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim)
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )
        self.time_projection = nn.Sequential(
            nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList([
            DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps)
            for _ in range(num_layers)
        ])
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)

        if has_image_input:
            self.img_emb = MLP(1280, dim)  # clip_feature_dim = 1280

        # ControlNet-specific layers: a 1x1x1 conv that maps the control video
        # into the latent channel space, plus one linear projection per DiT
        # block for the residual outputs (both are zero-initialized by the
        # state dict converter below, so the control branch starts as a no-op).
        self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Linear(dim, dim, bias=False)
            for _ in range(num_layers)
        ])

    def patchify(self, x: torch.Tensor):
        x = self.patch_embedding(x)
        grid_size = x.shape[2:]
        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
        return x, grid_size  # x, grid_size: (f, h, w)

    def unpatchify(self, x: torch.Tensor, grid_size: Tuple[int, int, int]):
        return rearrange(
            x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2],
            x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
        )

    def forward(self,
                x: torch.Tensor,
                timestep: torch.Tensor,
                context: torch.Tensor,
                clip_feature: Optional[torch.Tensor] = None,
                y: Optional[torch.Tensor] = None,
                controlnet_conditioning: Optional[torch.Tensor] = None,
                use_gradient_checkpointing: bool = False,
                use_gradient_checkpointing_offload: bool = False,
                **kwargs,
                ):
        t = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, timestep))
        t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
        context = self.text_embedding(context)

        if self.has_image_input:
            x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
            clip_embedding = self.img_emb(clip_feature)
            context = torch.cat([clip_embedding, context], dim=1)

        # Inject the control signal before patchification.
        x = x + self.controlnet_conv_in(controlnet_conditioning)

        x, (f, h, w) = self.patchify(x)

        # Assemble 3D RoPE frequencies for the (f, h, w) token grid.
        freqs = torch.cat([
            self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        # Run the DiT blocks, collecting the hidden states after each one.
        res_stack = []
        for block in self.blocks:
            if self.training and use_gradient_checkpointing:
                if use_gradient_checkpointing_offload:
                    with torch.autograd.graph.save_on_cpu():
                        x = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(block),
                            x, context, t_mod, freqs,
                            use_reentrant=False,
                        )
                else:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = block(x, context, t_mod, freqs)
            res_stack.append(x)

        # Project each per-block hidden state through its linear layer; the
        # results are added to the base model's hidden states at matching depths.
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
        return controlnet_res_stack

    @staticmethod
    def state_dict_converter():
        return WanControlNetModelStateDictConverter()


class WanControlNetModelStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict

    def from_base_model(self, state_dict):
        # Recognize the base DiT checkpoint by a hash of its keys and pick the
        # matching architecture config.
        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
            # text-to-video, 1536-dim / 30-layer variant
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 1536,
                "ffn_dim": 8960,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 12,
                "num_layers": 30,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
            # text-to-video, 5120-dim / 40-layer variant
            config = {
                "has_image_input": False,
                "patch_size": [1, 2, 2],
                "in_dim": 16,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
            # image-to-video, 5120-dim / 40-layer variant
            config = {
                "has_image_input": True,
                "patch_size": [1, 2, 2],
                "in_dim": 36,
                "dim": 5120,
                "ffn_dim": 13824,
                "freq_dim": 256,
                "text_dim": 4096,
                "out_dim": 16,
                "num_heads": 40,
                "num_layers": 40,
                "eps": 1e-6
            }
        else:
            config = {}

        # Copy every non-head parameter from the base model.
        state_dict_ = {}
        dtype, device = None, None
        for name, param in state_dict.items():
            if name.startswith("head."):
                continue
            state_dict_[name] = param
            dtype, device = param.dtype, param.device

        # Zero-initialize the ControlNet-specific layers so the control branch
        # initially contributes nothing, as in the original ControlNet recipe.
        for block_id in range(config["num_layers"]):
            zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device)
            state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone()
        state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device)
        state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device)
        return state_dict_, config
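

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the module's
    # API). It uses the 1536-dim text-to-video configuration recovered by
    # `from_base_model` above; all tensor shapes are invented for the example.
    # Because of the relative imports, this runs only as a module inside the
    # package, not as a standalone script.
    model = WanControlNetModel(
        dim=1536, in_dim=16, ffn_dim=8960, out_dim=16, text_dim=4096,
        freq_dim=256, eps=1e-6, patch_size=(1, 2, 2), num_heads=12,
        num_layers=30, has_image_input=False,
    )
    x = torch.randn(1, 16, 4, 32, 32)       # latent video: (b, c, f, h, w)
    cond = torch.randn(1, 16, 4, 32, 32)    # control signal, same shape as x
    context = torch.randn(1, 77, 4096)      # text encoder hidden states
    timestep = torch.tensor([500])
    residuals = model(x, timestep, context, controlnet_conditioning=cond)
    # One residual per DiT block: 30 tensors of shape (1, 4*16*16, 1536),
    # meant to be added to the base model's hidden states at matching depths.
    print(len(residuals), residuals[0].shape)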