support wan2.2 5B I2V (https://github.com/modelscope/DiffSynth-Studio.git)
@@ -212,9 +212,16 @@ class DiTBlock(nn.Module):
         self.gate = GateModule()
 
     def forward(self, x, context, t_mod, freqs):
+        has_seq = len(t_mod.shape) == 4
+        chunk_dim = 2 if has_seq else 1
         # msa: multi-head self-attention  mlp: multi-layer perceptron
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
+            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
+        if has_seq:
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
+                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
+            )
         input_x = modulate(self.norm1(x), shift_msa, scale_msa)
         x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
         x = x + self.cross_attn(self.norm3(x), context)
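Note: the new `has_seq` branch lets `t_mod` carry a per-token sequence axis for the 5B model. A minimal sketch of the shape handling, assuming the per-token layout is `[B, L, 6, D]` (inferred from the `chunk(6, dim=2)` / `squeeze(2)` pattern above; batch/seq/dim sizes are illustrative):

```python
import torch

batch, seq, dim = 2, 8, 64
modulation = torch.randn(1, 6, dim)  # mirrors DiTBlock.modulation

# Wan2.1-style conditioning: one vector per sample, t_mod is [B, 6, D].
t_mod = torch.randn(batch, 6, dim)
chunk_dim = 2 if len(t_mod.shape) == 4 else 1       # -> 1
shift_msa = (modulation + t_mod).chunk(6, dim=chunk_dim)[0]
print(shift_msa.shape)                              # torch.Size([2, 1, 64])

# Wan2.2 5B-style conditioning: one vector per token, t_mod is [B, L, 6, D].
# [1, 6, D] broadcasts against [B, L, 6, D] over the leading axes.
t_mod = torch.randn(batch, seq, 6, dim)
chunk_dim = 2 if len(t_mod.shape) == 4 else 1       # -> 2
shift_msa = (modulation + t_mod).chunk(6, dim=chunk_dim)[0]
print(shift_msa.squeeze(2).shape)                   # torch.Size([2, 8, 64])
```

After the `squeeze(2)`, each shift/scale/gate is `[B, L, D]`, i.e. one modulation vector per token instead of one per sample.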
@@ -253,8 +260,12 @@ class Head(nn.Module):
         self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
 
     def forward(self, x, t_mod):
-        shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
-        x = (self.head(self.norm(x) * (1 + scale) + shift))
+        if len(t_mod.shape) == 3:
+            shift, scale = (self.modulation.unsqueeze(0).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(2)).chunk(2, dim=2)
+            x = (self.head(self.norm(x) * (1 + scale.squeeze(2)) + shift.squeeze(2)))
+        else:
+            shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(2, dim=1)
+            x = (self.head(self.norm(x) * (1 + scale) + shift))
         return x
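Note: `Head.forward` applies the same idea. When `t_mod` is 3-D (`[B, L, D]`, one conditioning vector per token), both operands are lifted to 4-D so shift and scale come out per token; the `else` branch preserves the previous per-sample behavior. A sketch of the new branch with illustrative sizes:

```python
import torch

batch, seq, dim = 2, 8, 64
modulation = torch.randn(1, 2, dim)       # mirrors Head.modulation

t_mod = torch.randn(batch, seq, dim)      # per-token conditioning, [B, L, D]
# [1, 1, 2, D] + [B, L, 1, D] broadcasts to [B, L, 2, D]
shift, scale = (modulation.unsqueeze(0) + t_mod.unsqueeze(2)).chunk(2, dim=2)
print(shift.squeeze(2).shape)             # torch.Size([2, 8, 64])
```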
@@ -276,12 +287,14 @@ class WanModel(torch.nn.Module):
         has_ref_conv: bool = False,
         add_control_adapter: bool = False,
         in_dim_control_adapter: int = 24,
+        is_5b: bool = False,
     ):
         super().__init__()
         self.dim = dim
         self.freq_dim = freq_dim
         self.has_image_input = has_image_input
         self.patch_size = patch_size
+        self.is_5b = is_5b
 
         self.patch_embedding = nn.Conv3d(
             in_dim, dim, kernel_size=patch_size, stride=patch_size)
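Note: this hunk only plumbs the `is_5b` flag through the constructor; the `Conv3d` patch embedding shown in context is unchanged. For orientation, a sketch of what a stride-equals-kernel `Conv3d` patchifier does to a video latent (sizes are illustrative, and the flatten/transpose step is the usual way such output becomes a token sequence, not necessarily this file's exact code):

```python
import torch
import torch.nn as nn

in_dim, dim, patch_size = 16, 64, (1, 2, 2)
patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)

latent = torch.randn(1, in_dim, 4, 32, 32)    # [B, C, F, H, W] video latent
tokens = patch_embedding(latent)              # [1, 64, 4, 16, 16]
tokens = tokens.flatten(2).transpose(1, 2)    # [1, 4 * 16 * 16, 64]
print(tokens.shape)                           # torch.Size([1, 1024, 64])
```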
@@ -672,6 +685,7 @@ class WanModelStateDictConverter:
                 "num_heads": 24,
                 "num_layers": 30,
                 "eps": 1e-6,
+                "is_5b": True,
             }
         else:
             config = {}
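Note: the converter now tags the detected 5B checkpoint with `"is_5b": True`, presumably consumed via `WanModel(**config)`-style construction so the new code paths above activate. A minimal sketch of this detect-or-fallback pattern (the lookup key and helper are hypothetical; the real detection logic lives upstream of this hunk):

```python
# Illustrative config dispatch mirroring the converter's shape.
KNOWN_CONFIGS = {
    "wan2.2-5b-i2v": {"num_heads": 24, "num_layers": 30, "eps": 1e-6, "is_5b": True},
}

def config_for(model_id: str) -> dict:
    # Unknown checkpoints fall back to an empty config, as in the else branch.
    return KNOWN_CONFIGS.get(model_id, {})

print(config_for("wan2.2-5b-i2v")["is_5b"])   # True
```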