diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 32a79e3..da1aafc 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -108,6 +108,16 @@ class RMSNorm(nn.Module): return self.norm(x.float()).to(dtype) * self.weight +class AttentionModule(nn.Module): + def __init__(self, num_heads): + super().__init__() + self.num_heads = num_heads + + def forward(self, q, k, v): + x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads) + return x + + class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): super().__init__() @@ -121,17 +131,16 @@ class SelfAttention(nn.Module): self.o = nn.Linear(dim, dim) self.norm_q = RMSNorm(dim, eps=eps) self.norm_k = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) def forward(self, x, freqs): q = self.norm_q(self.q(x)) k = self.norm_k(self.k(x)) v = self.v(x) - x = flash_attention( - q=rope_apply(q, freqs, self.num_heads), - k=rope_apply(k, freqs, self.num_heads), - v=v, - num_heads=self.num_heads - ) + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + x = self.attn(q, k, v) return self.o(x) @@ -153,6 +162,8 @@ class CrossAttention(nn.Module): self.k_img = nn.Linear(dim, dim) self.v_img = nn.Linear(dim, dim) self.norm_k_img = RMSNorm(dim, eps=eps) + + self.attn = AttentionModule(self.num_heads) def forward(self, x: torch.Tensor, y: torch.Tensor): if self.has_image_input: @@ -163,7 +174,7 @@ class CrossAttention(nn.Module): q = self.norm_q(self.q(x)) k = self.norm_k(self.k(ctx)) v = self.v(ctx) - x = flash_attention(q, k, v, num_heads=self.num_heads) + x = self.attn(q, k, v) if self.has_image_input: k_img = self.norm_k_img(self.k_img(img)) v_img = self.v_img(img) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 76e1fa0..439d311 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -225,7 +225,7 @@ 
class WanVideoPipeline(BasePipeline): tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) # Initialize noise noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32) diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index fde8849..d6d0219 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -37,7 +37,7 @@ class FlowMatchScheduler(): self.linear_timesteps_weights = bsmntw_weighing - def step(self, model_output, timestep, sample, to_final=False): + def step(self, model_output, timestep, sample, to_final=False, **kwargs): if isinstance(timestep, torch.Tensor): timestep = timestep.cpu() timestep_id = torch.argmin((self.timesteps - timestep).abs()) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 92f25a3..b3f5ade 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -49,6 +49,8 @@ We present a detailed table here. The model is tested on a single A100. https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f +The tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py). + ### Wan-Video-14B-I2V Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
diff --git a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py new file mode 100644 index 0000000..b4f5612 --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py @@ -0,0 +1,125 @@ +import torch +import lightning as pl +from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.tensor.parallel import parallelize_module +from lightning.pytorch.strategies import ModelParallelStrategy +from diffsynth import ModelManager, WanVideoPipeline, save_video +from tqdm import tqdm +from modelscope import snapshot_download + + + +class ToyDataset(torch.utils.data.Dataset): + def __init__(self, tasks=[]): + self.tasks = tasks + + def __getitem__(self, data_id): + return self.tasks[data_id] + + def __len__(self): + return len(self.tasks) + + +class LitModel(pl.LightningModule): + def __init__(self): + super().__init__() + model_manager = ModelManager(device="cpu") + model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, + ) + self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") + + def configure_model(self): + tp_mesh = 
self.device_mesh["tensor_parallel"] + for block_id, block in enumerate(self.pipe.dit.blocks): + layer_tp_plan = { + "self_attn": PrepareModuleInput( + input_layouts=(Replicate(), Replicate()), + desired_input_layouts=(Replicate(), Shard(0)), + ), + "self_attn.q": SequenceParallel(), + "self_attn.k": SequenceParallel(), + "self_attn.v": SequenceParallel(), + "self_attn.norm_q": SequenceParallel(), + "self_attn.norm_k": SequenceParallel(), + "self_attn.attn": PrepareModuleInput( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(2), Shard(2), Shard(2)), + ), + "self_attn.o": ColwiseParallel(output_layouts=Replicate()), + + "cross_attn": PrepareModuleInput( + input_layouts=(Replicate(), Replicate()), + desired_input_layouts=(Replicate(), Replicate()), + ), + "cross_attn.q": SequenceParallel(), + "cross_attn.k": SequenceParallel(), + "cross_attn.v": SequenceParallel(), + "cross_attn.norm_q": SequenceParallel(), + "cross_attn.norm_k": SequenceParallel(), + "cross_attn.attn": PrepareModuleInput( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(2), Shard(2), Shard(2)), + ), + "cross_attn.o": ColwiseParallel(output_layouts=Replicate()), + + "ffn.0": ColwiseParallel(), + "ffn.2": RowwiseParallel(), + } + parallelize_module( + module=block, + device_mesh=tp_mesh, + parallelize_plan=layer_tp_plan, + ) + + + def test_step(self, batch): + data = batch[0] + data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x + output_path = data.pop("output_path") + with torch.no_grad(), torch.inference_mode(False): + video = self.pipe(**data) + if self.local_rank == 0: + save_video(video, output_path, fps=15, quality=5) + + + +if __name__ == "__main__": + snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") + dataloader = torch.utils.data.DataLoader( + ToyDataset([ + { + "prompt": 
"一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "num_inference_steps": 50, + "seed": 0, + "tiled": False, + "output_path": "video1.mp4", + }, + { + "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "num_inference_steps": 50, + "seed": 1, + "tiled": False, + "output_path": "video2.mp4", + }, + ]), + collate_fn=lambda x: x + ) + model = LitModel() + trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy()) + trainer.test(model, dataloader) \ No newline at end of file