From fdeb363fa2a9f1acda13b47df54da30757fd5d05 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 19:50:33 +0800 Subject: [PATCH] wans2v usp --- diffsynth/models/wan_video_dit_s2v.py | 9 +++++-- diffsynth/pipelines/wan_video_new.py | 27 +++++++++++++------ .../model_inference/Wan2.1-S2V-14B.py | 26 +++++++++++------- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index b0016df..75b19a4 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -459,10 +459,13 @@ class WanS2VModel(torch.nn.Module): ) return x, rope_embs, mask_input - def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len): + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False): if block_idx in self.audio_injector.injected_block_id.keys(): audio_attn_id = self.audio_injector.injected_block_id[block_idx] num_frames = audio_emb.shape[1] + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sp_group + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) @@ -476,7 +479,9 @@ class WanS2VModel(torch.nn.Module): residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out - + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] return hidden_states def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1362d09..cef7ae8 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1284,11 +1284,11 @@ def model_fn_wans2v( # x and pose_cond pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) - seq_len_x = x.shape[1] + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel # reference image - ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) x = torch.cat([x, ref_latents], dim=1) # mask @@ -1305,6 +1305,14 @@ def model_fn_wans2v( t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) @@ -1315,7 +1323,7 @@ def model_fn_wans2v( with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -1326,7 +1334,7 @@ def model_fn_wans2v( elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -1335,10 +1343,13 @@ def model_fn_wans2v( use_reentrant=False, ) else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) - x = x[:, :seq_len_x] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] x = dit.head(x, t[:-1]) x = dit.unpatchify(x, (f, h, w)) # make compatible with wan video diff --git a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py index 73d4a49..bb93871 100644 --- a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py +++ b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py @@ -1,8 +1,9 @@ import torch from PIL import Image import librosa -from diffsynth import save_video, VideoData, save_video_with_audio +from diffsynth import VideoData, save_video_with_audio from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, @@ -15,21 +16,28 @@ pipe = WanVideoPipeline.from_pretrained( ], audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), ) +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*" +) + num_frames = 81 # 4n+1 height = 448 width = 832 prompt = "a person is singing" -input_image = Image.open("/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.png").convert("RGB").resize((width, height)) +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) # s2v audio input, recommend 16kHz sampling rate -audio_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/sing.MP3' +audio_path = 'data/example_video_dataset/wans2v/sing.MP3' input_audio, sample_rate = librosa.load(audio_path, sr=16000) # Speech-to-video video = pipe( prompt=prompt, input_image=input_image, - negative_prompt="", + negative_prompt=negative_prompt, seed=0, num_frames=num_frames, height=height, @@ -38,18 +46,17 @@ video = pipe( input_audio=input_audio, num_inference_steps=40, ) -save_video_with_audio(video, "video_with_audio.mp4", audio_path, fps=16, quality=5) +save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5) # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. -pose_video_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.mp4' +pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4' pose_video = VideoData(pose_video_path, height=height, width=width) -pose_video.set_length(num_frames) # Speech-to-video with pose video = pipe( prompt=prompt, input_image=input_image, - negative_prompt="", + negative_prompt=negative_prompt, seed=0, num_frames=num_frames, height=height, @@ -59,5 +66,4 @@ video = pipe( s2v_pose_video=pose_video, num_inference_steps=40, ) -save_video_with_audio(video, "video_pose_with_audio.mp4", audio_path, fps=16, quality=5) -save_video(pose_video, "video_pose_input.mp4", fps=16, quality=5) +save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)