diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index da1aafc..650e08f 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -183,6 +183,13 @@ class CrossAttention(nn.Module): return self.o(x) +class GateModule(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, gate, residual): + return x + gate * residual + class DiTBlock(nn.Module): def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6): super().__init__() @@ -199,16 +206,17 @@ class DiTBlock(nn.Module): self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU( approximate='tanh'), nn.Linear(ffn_dim, dim)) self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.gate = GateModule() def forward(self, x, context, t_mod, freqs): # msa: multi-head self-attention mlp: multi-layer perceptron shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) input_x = modulate(self.norm1(x), shift_msa, scale_msa) - x = x + gate_msa * self.self_attn(input_x, freqs) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) x = x + self.cross_attn(self.norm3(x), context) input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) - x = x + gate_mlp * self.ffn(input_x) + x = self.gate(x, gate_mlp, self.ffn(input_x)) return x diff --git a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py index b4f5612..77c230c 100644 --- a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py +++ b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py @@ -44,11 +44,28 @@ class LitModel(pl.LightningModule): def configure_model(self): tp_mesh = self.device_mesh["tensor_parallel"] + plan = { + "text_embedding.0": ColwiseParallel(), + "text_embedding.2": RowwiseParallel(), + "time_projection.1": ColwiseParallel(output_layouts=Replicate()), + "text_embedding.0": ColwiseParallel(), + "text_embedding.2": RowwiseParallel(), + "blocks.0": PrepareModuleInput( + input_layouts=(Replicate(), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), + ), + "head": PrepareModuleInput( + input_layouts=(Replicate(), None), + desired_input_layouts=(Replicate(), None), + use_local_output=True, + ) + } + self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan) for block_id, block in enumerate(self.pipe.dit.blocks): layer_tp_plan = { "self_attn": PrepareModuleInput( - input_layouts=(Replicate(), Replicate()), - desired_input_layouts=(Replicate(), Shard(0)), + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Shard(1), Shard(0)), ), "self_attn.q": SequenceParallel(), "self_attn.k": SequenceParallel(), @@ -59,11 +76,11 @@ class LitModel(pl.LightningModule): input_layouts=(Shard(1), Shard(1), Shard(1)), desired_input_layouts=(Shard(2), Shard(2), Shard(2)), ), - "self_attn.o": ColwiseParallel(output_layouts=Replicate()), - + "self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()), + "cross_attn": PrepareModuleInput( - input_layouts=(Replicate(), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Shard(1), Replicate()), ), "cross_attn.q": SequenceParallel(), "cross_attn.k": SequenceParallel(), @@ -74,18 +91,26 @@ class LitModel(pl.LightningModule): input_layouts=(Shard(1), Shard(1), Shard(1)), desired_input_layouts=(Shard(2), Shard(2), Shard(2)), ), - "cross_attn.o": ColwiseParallel(output_layouts=Replicate()), - - "ffn.0": ColwiseParallel(), - "ffn.2": RowwiseParallel(), + "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False), + + "ffn.0": ColwiseParallel(input_layouts=Shard(1)), + "ffn.2": RowwiseParallel(output_layouts=Replicate()), + + "norm1": SequenceParallel(use_local_output=True), + "norm2": SequenceParallel(use_local_output=True), + "norm3": SequenceParallel(use_local_output=True), + "gate": PrepareModuleInput( + input_layouts=(Shard(1), Replicate(), Replicate()), + desired_input_layouts=(Replicate(), Replicate(), Replicate()), + ) } parallelize_module( module=block, device_mesh=tp_mesh, parallelize_plan=layer_tp_plan, ) - - + + def test_step(self, batch): data = batch[0] data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x @@ -94,9 +119,8 @@ class LitModel(pl.LightningModule): video = self.pipe(**data) if self.local_rank == 0: save_video(video, output_path, fps=15, quality=5) - - - + + if __name__ == "__main__": snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") dataloader = torch.utils.data.DataLoader(