wans2v usp

This commit is contained in:
mi804
2025-08-27 19:50:33 +08:00
parent 4147473c81
commit fdeb363fa2
3 changed files with 42 additions and 20 deletions

View File

@@ -459,10 +459,13 @@ class WanS2VModel(torch.nn.Module):
)
return x, rope_embs, mask_input
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
if block_idx in self.audio_injector.injected_block_id.keys():
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
num_frames = audio_emb.shape[1]
if use_unified_sequence_parallel:
from xfuser.core.distributed import get_sp_group
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
@@ -476,7 +479,9 @@ class WanS2VModel(torch.nn.Module):
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
if use_unified_sequence_parallel:
from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
return hidden_states
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):

View File

@@ -1284,11 +1284,11 @@ def model_fn_wans2v(
# x and pose_cond
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
seq_len_x = x.shape[1]
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
# reference image
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
x = torch.cat([x, ref_latents], dim=1)
# mask
@@ -1305,6 +1305,14 @@ def model_fn_wans2v(
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
x = torch.chunk(x, world_size, dim=1)[sp_rank]
seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
seq_len_x = seq_len_x_list[sp_rank]
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
@@ -1315,7 +1323,7 @@ def model_fn_wans2v(
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -1326,7 +1334,7 @@ def model_fn_wans2v(
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, t_mod, seq_len_x, pre_compute_freqs,
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -1335,10 +1343,13 @@ def model_fn_wans2v(
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
x = x[:, :seq_len_x]
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = x[:, :seq_len_x_global]
x = dit.head(x, t[:-1])
x = dit.unpatchify(x, (f, h, w))
# make compatible with wan video

View File

@@ -1,8 +1,9 @@
import torch
from PIL import Image
import librosa
from diffsynth import save_video, VideoData, save_video_with_audio
from diffsynth import VideoData, save_video_with_audio
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
@@ -15,21 +16,28 @@ pipe = WanVideoPipeline.from_pretrained(
],
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_video_dataset",
local_dir="./data/example_video_dataset",
allow_file_pattern=f"wans2v/*"
)
num_frames = 81 # 4n+1
height = 448
width = 832
prompt = "a person is singing"
input_image = Image.open("/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.png").convert("RGB").resize((width, height))
negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height))
# s2v audio input, recommend 16kHz sampling rate
audio_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/sing.MP3'
audio_path = 'data/example_video_dataset/wans2v/sing.MP3'
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
# Speech-to-video
video = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt="",
negative_prompt=negative_prompt,
seed=0,
num_frames=num_frames,
height=height,
@@ -38,18 +46,17 @@ video = pipe(
input_audio=input_audio,
num_inference_steps=40,
)
save_video_with_audio(video, "video_with_audio.mp4", audio_path, fps=16, quality=5)
save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5)
# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps.
pose_video_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.mp4'
pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4'
pose_video = VideoData(pose_video_path, height=height, width=width)
pose_video.set_length(num_frames)
# Speech-to-video with pose
video = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt="",
negative_prompt=negative_prompt,
seed=0,
num_frames=num_frames,
height=height,
@@ -59,5 +66,4 @@ video = pipe(
s2v_pose_video=pose_video,
num_inference_steps=40,
)
save_video_with_audio(video, "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
save_video(pose_video, "video_pose_input.mp4", fps=16, quality=5)
save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)