mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
Merge pull request #457 from modelscope/wan-tp
support wan tensor parallel (preview)
This commit is contained in:
@@ -108,6 +108,16 @@ class RMSNorm(nn.Module):
|
|||||||
return self.norm(x.float()).to(dtype) * self.weight
|
return self.norm(x.float()).to(dtype) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionModule(nn.Module):
|
||||||
|
def __init__(self, num_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def forward(self, q, k, v):
|
||||||
|
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -122,16 +132,15 @@ class SelfAttention(nn.Module):
|
|||||||
self.norm_q = RMSNorm(dim, eps=eps)
|
self.norm_q = RMSNorm(dim, eps=eps)
|
||||||
self.norm_k = RMSNorm(dim, eps=eps)
|
self.norm_k = RMSNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
self.attn = AttentionModule(self.num_heads)
|
||||||
|
|
||||||
def forward(self, x, freqs):
|
def forward(self, x, freqs):
|
||||||
q = self.norm_q(self.q(x))
|
q = self.norm_q(self.q(x))
|
||||||
k = self.norm_k(self.k(x))
|
k = self.norm_k(self.k(x))
|
||||||
v = self.v(x)
|
v = self.v(x)
|
||||||
x = flash_attention(
|
q = rope_apply(q, freqs, self.num_heads)
|
||||||
q=rope_apply(q, freqs, self.num_heads),
|
k = rope_apply(k, freqs, self.num_heads)
|
||||||
k=rope_apply(k, freqs, self.num_heads),
|
x = self.attn(q, k, v)
|
||||||
v=v,
|
|
||||||
num_heads=self.num_heads
|
|
||||||
)
|
|
||||||
return self.o(x)
|
return self.o(x)
|
||||||
|
|
||||||
|
|
||||||
@@ -154,6 +163,8 @@ class CrossAttention(nn.Module):
|
|||||||
self.v_img = nn.Linear(dim, dim)
|
self.v_img = nn.Linear(dim, dim)
|
||||||
self.norm_k_img = RMSNorm(dim, eps=eps)
|
self.norm_k_img = RMSNorm(dim, eps=eps)
|
||||||
|
|
||||||
|
self.attn = AttentionModule(self.num_heads)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
if self.has_image_input:
|
if self.has_image_input:
|
||||||
img = y[:, :257]
|
img = y[:, :257]
|
||||||
@@ -163,7 +174,7 @@ class CrossAttention(nn.Module):
|
|||||||
q = self.norm_q(self.q(x))
|
q = self.norm_q(self.q(x))
|
||||||
k = self.norm_k(self.k(ctx))
|
k = self.norm_k(self.k(ctx))
|
||||||
v = self.v(ctx)
|
v = self.v(ctx)
|
||||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
x = self.attn(q, k, v)
|
||||||
if self.has_image_input:
|
if self.has_image_input:
|
||||||
k_img = self.norm_k_img(self.k_img(img))
|
k_img = self.norm_k_img(self.k_img(img))
|
||||||
v_img = self.v_img(img)
|
v_img = self.v_img(img)
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||||
|
|
||||||
# Initialize noise
|
# Initialize noise
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class FlowMatchScheduler():
|
|||||||
self.linear_timesteps_weights = bsmntw_weighing
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
|
||||||
|
|
||||||
def step(self, model_output, timestep, sample, to_final=False):
|
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||||
if isinstance(timestep, torch.Tensor):
|
if isinstance(timestep, torch.Tensor):
|
||||||
timestep = timestep.cpu()
|
timestep = timestep.cpu()
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ We present a detailed table here. The model is tested on a single A100.
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||||
|
|
||||||
|
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||||
|
|
||||||
### Wan-Video-14B-I2V
|
### Wan-Video-14B-I2V
|
||||||
|
|
||||||
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
|
||||||
|
|||||||
125
examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
Normal file
125
examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import torch
|
||||||
|
import lightning as pl
|
||||||
|
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
|
||||||
|
from torch.distributed._tensor import Replicate, Shard
|
||||||
|
from torch.distributed.tensor.parallel import parallelize_module
|
||||||
|
from lightning.pytorch.strategies import ModelParallelStrategy
|
||||||
|
from diffsynth import ModelManager, WanVideoPipeline, save_video
|
||||||
|
from tqdm import tqdm
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ToyDataset(torch.utils.data.Dataset):
    """Trivial in-memory dataset that serves a fixed list of task dicts.

    Each item is a dict of keyword arguments later unpacked into the
    video pipeline call (see ``LitModel.test_step``).
    """

    def __init__(self, tasks=None):
        # Original used a mutable default (`tasks=[]`), which is shared
        # across all instances constructed without arguments. Use None
        # as the sentinel instead; behavior for all callers is unchanged.
        self.tasks = [] if tasks is None else tasks

    def __getitem__(self, data_id):
        # Return the task dict at the given index, unchanged.
        return self.tasks[data_id]

    def __len__(self):
        return len(self.tasks)
|
|
||||||
|
|
||||||
|
class LitModel(pl.LightningModule):
    """LightningModule that runs WanVideoPipeline inference with the DiT
    blocks sharded via PyTorch tensor parallelism.

    Weights are loaded on CPU in ``__init__``; ``configure_model`` applies
    the per-block tensor-parallel plan once the device mesh exists.
    """

    def __init__(self):
        super().__init__()
        # Load all checkpoint shards on CPU first; sharding happens later
        # in configure_model once Lightning has built the device mesh.
        model_manager = ModelManager(device="cpu")
        dit_shards = [
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
        ]
        model_manager.load_models(
            [
                dit_shards,
                "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
                "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
            ],
            torch_dtype=torch.bfloat16,
        )
        self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

    def configure_model(self):
        # Apply the same tensor-parallel plan to every DiT block.
        tp_mesh = self.device_mesh["tensor_parallel"]
        for block in self.pipe.dit.blocks:
            plan = {
                # Self-attention: replicate x, shard freqs along dim 0.
                "self_attn": PrepareModuleInput(
                    input_layouts=(Replicate(), Replicate()),
                    desired_input_layouts=(Replicate(), Shard(0)),
                ),
                "self_attn.q": SequenceParallel(),
                "self_attn.k": SequenceParallel(),
                "self_attn.v": SequenceParallel(),
                "self_attn.norm_q": SequenceParallel(),
                "self_attn.norm_k": SequenceParallel(),
                # q/k/v arrive sequence-sharded (dim 1); redistribute to
                # head/feature sharding (dim 2) before attention.
                "self_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "self_attn.o": ColwiseParallel(output_layouts=Replicate()),
                # Cross-attention mirrors the self-attention layout, with
                # the context input kept replicated.
                "cross_attn": PrepareModuleInput(
                    input_layouts=(Replicate(), Replicate()),
                    desired_input_layouts=(Replicate(), Replicate()),
                ),
                "cross_attn.q": SequenceParallel(),
                "cross_attn.k": SequenceParallel(),
                "cross_attn.v": SequenceParallel(),
                "cross_attn.norm_q": SequenceParallel(),
                "cross_attn.norm_k": SequenceParallel(),
                "cross_attn.attn": PrepareModuleInput(
                    input_layouts=(Shard(1), Shard(1), Shard(1)),
                    desired_input_layouts=(Shard(2), Shard(2), Shard(2)),
                ),
                "cross_attn.o": ColwiseParallel(output_layouts=Replicate()),
                # Feed-forward: classic colwise -> rowwise pair.
                "ffn.0": ColwiseParallel(),
                "ffn.2": RowwiseParallel(),
            }
            parallelize_module(module=block, device_mesh=tp_mesh, parallelize_plan=plan)

    def test_step(self, batch):
        # One task dict per step (collate_fn is identity, so batch is a list).
        data = batch[0]
        # Only rank 0 shows a progress bar; other ranks get a no-op wrapper.
        data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x
        output_path = data.pop("output_path")
        # inference_mode(False) re-enables ordinary no-grad tensors, which
        # the pipeline requires under Lightning's test loop.
        with torch.no_grad(), torch.inference_mode(False):
            video = self.pipe(**data)
        if self.local_rank == 0:
            save_video(video, output_path, fps=15, quality=5)
|
||||||
|
if __name__ == "__main__":
    # Fetch the Wan2.1-T2V-14B checkpoint from ModelScope (no-op if cached).
    snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")

    # Two generation tasks, identical except for seed and output path.
    prompt = "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。"
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    tasks = [
        {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": 50,
            "seed": 0,
            "tiled": False,
            "output_path": "video1.mp4",
        },
        {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": 50,
            "seed": 1,
            "tiled": False,
            "output_path": "video2.mp4",
        },
    ]

    # Identity collate: each batch is the raw list of task dicts.
    dataloader = torch.utils.data.DataLoader(ToyDataset(tasks), collate_fn=lambda x: x)
    model = LitModel()
    # ModelParallelStrategy provides the "tensor_parallel" device mesh
    # consumed by LitModel.configure_model.
    trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy())
    trainer.test(model, dataloader)
Reference in New Issue
Block a user