mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
update tensor parallel
This commit is contained in:
@@ -183,6 +183,13 @@ class CrossAttention(nn.Module):
|
|||||||
return self.o(x)
|
return self.o(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GateModule(nn.Module):
    """Gated residual connection packaged as a parameter-free module.

    Computes ``x + gate * residual``. Expressing this step as a submodule
    (instead of an inline expression) gives it a name in the module tree, so
    code that addresses submodules by name -- e.g. a parallelization plan
    keyed on ``"gate"`` -- can attach input handling to it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        """Add ``residual``, scaled elementwise by ``gate``, onto ``x``.

        Args:
            x: Value the gated residual is added to.
            gate: Multiplier applied to ``residual``.
            residual: Branch output to be gated and added.

        Returns:
            ``x + gate * residual``; the usual broadcasting rules of the
            operand types apply.
        """
        return x + gate * residual
|
||||||
|
|
||||||
class DiTBlock(nn.Module):
|
class DiTBlock(nn.Module):
|
||||||
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -199,16 +206,17 @@ class DiTBlock(nn.Module):
|
|||||||
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
||||||
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
||||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
|
self.gate = GateModule()
|
||||||
|
|
||||||
def forward(self, x, context, t_mod, freqs):
    """Run one block: gated self-attention, cross-attention, gated FFN.

    Args:
        x: Input hidden states; each stage adds its output onto this.
        context: Conditioning input passed to ``self.cross_attn``.
        t_mod: Modulation input added to ``self.modulation``; the sum is
            split into six chunks along dim 1 (shift/scale/gate for the
            self-attention branch and for the FFN branch).
        freqs: Passed through unchanged to ``self.self_attn``.

    Returns:
        The updated hidden states.
    """
    # msa: multi-head self-attention  mlp: multi-layer perceptron
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)

    # Self-attention branch. The gated residual add goes through self.gate
    # (x + gate * residual) rather than an inline expression, so the step
    # exists as a named submodule.
    input_x = modulate(self.norm1(x), shift_msa, scale_msa)
    x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))

    # Cross-attention branch: plain (ungated, unmodulated) residual add.
    x = x + self.cross_attn(self.norm3(x), context)

    # Feed-forward branch, gated the same way as self-attention.
    input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
    x = self.gate(x, gate_mlp, self.ffn(input_x))
    return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,11 +44,28 @@ class LitModel(pl.LightningModule):
|
|||||||
|
|
||||||
def configure_model(self):
|
def configure_model(self):
|
||||||
tp_mesh = self.device_mesh["tensor_parallel"]
|
tp_mesh = self.device_mesh["tensor_parallel"]
|
||||||
|
plan = {
|
||||||
|
"text_embedding.0": ColwiseParallel(),
|
||||||
|
"text_embedding.2": RowwiseParallel(),
|
||||||
|
"time_projection.1": ColwiseParallel(output_layouts=Replicate()),
|
||||||
|
"text_embedding.0": ColwiseParallel(),
|
||||||
|
"text_embedding.2": RowwiseParallel(),
|
||||||
|
"blocks.0": PrepareModuleInput(
|
||||||
|
input_layouts=(Replicate(), None, None, None),
|
||||||
|
desired_input_layouts=(Replicate(), None, None, None),
|
||||||
|
),
|
||||||
|
"head": PrepareModuleInput(
|
||||||
|
input_layouts=(Replicate(), None),
|
||||||
|
desired_input_layouts=(Replicate(), None),
|
||||||
|
use_local_output=True,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
|
||||||
for block_id, block in enumerate(self.pipe.dit.blocks):
|
for block_id, block in enumerate(self.pipe.dit.blocks):
|
||||||
layer_tp_plan = {
|
layer_tp_plan = {
|
||||||
"self_attn": PrepareModuleInput(
|
"self_attn": PrepareModuleInput(
|
||||||
input_layouts=(Replicate(), Replicate()),
|
input_layouts=(Shard(1), Replicate()),
|
||||||
desired_input_layouts=(Replicate(), Shard(0)),
|
desired_input_layouts=(Shard(1), Shard(0)),
|
||||||
),
|
),
|
||||||
"self_attn.q": SequenceParallel(),
|
"self_attn.q": SequenceParallel(),
|
||||||
"self_attn.k": SequenceParallel(),
|
"self_attn.k": SequenceParallel(),
|
||||||
@@ -59,11 +76,11 @@ class LitModel(pl.LightningModule):
|
|||||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||||
),
|
),
|
||||||
"self_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
"self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()),
|
||||||
|
|
||||||
"cross_attn": PrepareModuleInput(
|
"cross_attn": PrepareModuleInput(
|
||||||
input_layouts=(Replicate(), Replicate()),
|
input_layouts=(Shard(1), Replicate()),
|
||||||
desired_input_layouts=(Replicate(), Replicate()),
|
desired_input_layouts=(Shard(1), Replicate()),
|
||||||
),
|
),
|
||||||
"cross_attn.q": SequenceParallel(),
|
"cross_attn.q": SequenceParallel(),
|
||||||
"cross_attn.k": SequenceParallel(),
|
"cross_attn.k": SequenceParallel(),
|
||||||
@@ -74,18 +91,26 @@ class LitModel(pl.LightningModule):
|
|||||||
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
input_layouts=(Shard(1), Shard(1), Shard(1)),
|
||||||
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
|
||||||
),
|
),
|
||||||
"cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
|
"cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False),
|
||||||
|
|
||||||
"ffn.0": ColwiseParallel(),
|
"ffn.0": ColwiseParallel(input_layouts=Shard(1)),
|
||||||
"ffn.2": RowwiseParallel(),
|
"ffn.2": RowwiseParallel(output_layouts=Replicate()),
|
||||||
|
|
||||||
|
"norm1": SequenceParallel(use_local_output=True),
|
||||||
|
"norm2": SequenceParallel(use_local_output=True),
|
||||||
|
"norm3": SequenceParallel(use_local_output=True),
|
||||||
|
"gate": PrepareModuleInput(
|
||||||
|
input_layouts=(Shard(1), Replicate(), Replicate()),
|
||||||
|
desired_input_layouts=(Replicate(), Replicate(), Replicate()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
parallelize_module(
|
parallelize_module(
|
||||||
module=block,
|
module=block,
|
||||||
device_mesh=tp_mesh,
|
device_mesh=tp_mesh,
|
||||||
parallelize_plan=layer_tp_plan,
|
parallelize_plan=layer_tp_plan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_step(self, batch):
|
def test_step(self, batch):
|
||||||
data = batch[0]
|
data = batch[0]
|
||||||
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
|
data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
|
||||||
@@ -94,9 +119,8 @@ class LitModel(pl.LightningModule):
|
|||||||
video = self.pipe(**data)
|
video = self.pipe(**data)
|
||||||
if self.local_rank == 0:
|
if self.local_rank == 0:
|
||||||
save_video(video, output_path, fps=15, quality=5)
|
save_video(video, output_path, fps=15, quality=5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
|||||||
Reference in New Issue
Block a user