mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update tensor parallel
This commit is contained in:
@@ -44,11 +44,28 @@ class LitModel(pl.LightningModule):
|
||||
|
||||
def configure_model(self):
|
||||
tp_mesh = self.device_mesh["tensor_parallel"]
|
||||
plan = {
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
|
||||
"text_embedding.0": ColwiseParallel(),
|
||||
"text_embedding.2": RowwiseParallel(),
|
||||
"blocks.0": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None, None, None),
|
||||
desired_input_layouts=(Replicate(), None, None, None),
|
||||
),
|
||||
"head": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), None),
|
||||
desired_input_layouts=(Replicate(), None),
|
||||
use_local_output=True,
|
||||
)
|
||||
}
|
||||
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
|
||||
for block_id, block in enumerate(self.pipe.dit.blocks):
|
||||
layer_tp_plan = {
|
||||
"self_attn": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Shard(0)),
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Shard(0)),
|
||||
),
|
||||
"self_attn.q": SequenceParallel(),
|
||||
"self_attn.k": SequenceParallel(),
|
||||
@@ -59,11 +76,11 @@ class LitModel(pl.LightningModule):
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
|
||||
|
||||
"cross_attn": PrepareModuleInput(
|
||||
input_layouts=(Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Replicate()),
|
||||
input_layouts=(Shard(1), Replicate()),
|
||||
desired_input_layouts=(Shard(1), Replicate()),
|
||||
),
|
||||
"cross_attn.q": SequenceParallel(),
|
||||
"cross_attn.k": SequenceParallel(),
|
||||
@@ -74,18 +91,26 @@ class LitModel(pl.LightningModule):
|
||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||
),
|
||||
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"ffn.0": ColwiseParallel(),
|
||||
"ffn.2": RowwiseParallel(),
|
||||
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
|
||||
|
||||
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
|
||||
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
|
||||
|
||||
"norm1": SequenceParallel(use_local_output=True),
|
||||
"norm2": SequenceParallel(use_local_output=True),
|
||||
"norm3": SequenceParallel(use_local_output=True),
|
||||
"gate": PrepareModuleInput(
|
||||
input_layouts=(Shard(1), Replicate(), Replicate()),
|
||||
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
|
||||
)
|
||||
}
|
||||
parallelize_module(
|
||||
module=block,
|
||||
device_mesh=tp_mesh,
|
||||
parallelize_plan=layer_tp_plan,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_step(self, batch):
|
||||
data = batch[0]
|
||||
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
|
||||
@@ -94,9 +119,8 @@ class LitModel(pl.LightningModule):
|
||||
video = self.pipe(**data)
|
||||
if self.local_rank == 0:
|
||||
save_video(video, output_path, fps=15, quality=5)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user