From b541b9bed27918ada53f2f5f58c9930b22101327 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 11:51:56 +0800 Subject: [PATCH 1/6] wans2v inference --- diffsynth/configs/model_config.py | 4 + diffsynth/data/__init__.py | 2 +- diffsynth/data/video.py | 69 +++ diffsynth/models/wan_video_dit_s2v.py | 579 ++++++++++++++++++ diffsynth/models/wav2vec.py | 199 ++++++ diffsynth/pipelines/wan_video_new.py | 222 ++++++- .../model_inference/Wan2.1-S2V-14B.py | 63 ++ 7 files changed, 1134 insertions(+), 4 deletions(-) create mode 100644 diffsynth/models/wan_video_dit_s2v.py create mode 100644 diffsynth/models/wav2vec.py create mode 100644 examples/wanvideo/model_inference/Wan2.1-S2V-14B.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b4b847f..43fe84b 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -56,11 +56,13 @@ from ..models.stepvideo_vae import StepVideoVAE from ..models.stepvideo_dit import StepVideoModel from ..models.wan_video_dit import WanModel +from ..models.wan_video_dit_s2v import WanS2VModel from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38 from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel +from ..models.wav2vec import WanS2VAudioEncoder from ..models.step1x_connector import Qwen2Connector @@ -155,6 +157,7 @@ model_loader_configs = [ (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), + (None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), @@ -172,6 +175,7 @@ model_loader_configs = [ (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), + (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/data/__init__.py b/diffsynth/data/__init__.py index de09a29..749c03f 100644 --- a/diffsynth/data/__init__.py +++ b/diffsynth/data/__init__.py @@ -1 +1 @@ -from .video import VideoData, save_video, save_frames +from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio diff --git a/diffsynth/data/video.py b/diffsynth/data/video.py index 8eafa66..c6b9daa 100644 --- a/diffsynth/data/video.py +++ b/diffsynth/data/video.py @@ -2,6 +2,8 @@ import imageio, os import numpy as np from PIL import Image from tqdm import tqdm +import subprocess +import shutil class LowMemoryVideo: @@ -146,3 +148,70 @@ def save_frames(frames, save_path): os.makedirs(save_path, exist_ok=True) for i, frame in enumerate(tqdm(frames, desc="Saving images")): frame.save(os.path.join(save_path, f"{i}.png")) + + +def merge_video_audio(video_path: str, audio_path: str): + # TODO: may need a in-python implementation to avoid subprocess dependency + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. + + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + + +def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None): + save_video(frames, save_path, fps, quality, ffmpeg_params) + merge_video_audio(save_path, audio_path) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py new file mode 100644 index 0000000..3121c98 --- /dev/null +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -0,0 +1,579 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple +from .utils import hash_state_dict_keys +from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention + + +def torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +def rope_apply(x, freqs): + n, c = x.size(2), x.size(3) // 2 + # loop over samples + output = [] + for i, _ in enumerate(x): + s = x.size(1) + x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) + freqs_i = freqs[i, :s] + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, s:]]) + # append to collection + output.append(x_i) + return torch.stack(output).to(x.dtype) + + +def rope_precompute(x, grid_sizes, freqs, start=None): + b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 + + # split freqs + if type(freqs) is list: + trainable_freqs = freqs[1] + freqs = freqs[0] + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64)) + seq_bucket = [0] + if not type(grid_sizes) is list: + grid_sizes = [grid_sizes] + for g in grid_sizes: + if not type(g) is list: + g = [torch.zeros_like(g), g] + batch_size = g[0].shape[0] + for i in range(batch_size): + if start is None: + f_o, h_o, w_o = g[0][i] + else: + f_o, h_o, w_o = start[i] + + f, h, w = g[1][i] + t_f, t_h, t_w = g[2][i] + seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o + seq_len = int(seq_f * seq_h * seq_w) + if seq_len > 0: + if t_f > 0: + factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item() + # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item()) + if f_o >= 0: + f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist() + else: + f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist() + h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist() + w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist() + + assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0 + freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj() + freqs_0 = freqs_0.view(seq_f, 1, 1, -1) + + freqs_i = torch.cat( + [ + freqs_0.expand(seq_f, seq_h, seq_w, -1), + freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1), + freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1), + ], + dim=-1 + ).reshape(seq_len, 1, -1) + elif t_f < 0: + freqs_i = trainable_freqs.unsqueeze(1) + # apply rotary embedding + output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i + seq_bucket.append(seq_bucket[-1] + seq_len) + return output + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + +class MotionEncoder_tc(nn.Module): + + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.need_global = need_global + self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1) + if need_global: + self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2) + self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2) + + if need_global: + self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) + + self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + x = rearrange(x, 'b t c -> b c t') + x_ori = x.clone() + b, c, t = x.shape + x = self.conv1_local(x) + x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads) + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + if not self.need_global: + return x_local + + x = self.conv1_global(x_ori) + x = rearrange(x, 'b c t -> b t c') + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv2(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, 'b t c -> b c t') + x = self.conv3(x) + x = rearrange(x, 'b c t -> b t c') + x = self.norm3(x) + x = self.act(x) + x = self.final_linear(x) + x = rearrange(x, '(b n) t c -> b t n c', b=b) + + return x, x_local + + +class FramePackMotioner(nn.Module): + + def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs): + super().__init__(*args, **kwargs) + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long) + + self.inner_dim = inner_dim + self.num_heads = num_heads + self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1) + self.drop_mode = drop_mode + + def forward(self, motion_latents, add_last_motion=2): + motion_frames = motion_latents[0].shape[1] + mot = [] + mot_remb = [] + for m in motion_latents: + lat_height, lat_width = m.shape[2], m.shape[3] + padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype) + overlap_frame = min(padd_lat.shape[1], m.shape[1]) + if overlap_frame > 0: + padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:] + + if add_last_motion < 2 and self.drop_mode != "drop": + zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum() + padd_lat[:, -zero_end_frame:] = 0 + + padd_lat = padd_lat.unsqueeze(0) + clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split( + list(self.zip_frame_buckets)[::-1], dim=2 + ) # 16, 2 ,1 + + # patchfy + clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2) + clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2) + clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2) + + if add_last_motion < 2 and self.drop_mode == "drop": + clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post + clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x + + motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1) + + # rope + start_time_id = -(self.zip_frame_buckets[:1].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[0] + grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:2].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[1] // 2 + grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \ + [ + [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ] + ] + + start_time_id = -(self.zip_frame_buckets[:3].sum()) + end_time_id = start_time_id + self.zip_frame_buckets[2] // 4 + grid_sizes_4x = [ + [ + torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1), + torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1), + torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), + ] + ] + + grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x + + motion_rope_emb = rope_precompute( + motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads), + grid_sizes, + self.freqs, + start=None + ) + + mot.append(motion_lat) + mot_remb.append(motion_rope_emb) + return mot, mot_remb + + +class AdaLayerNorm(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_dim: int, + norm_eps: float = 1e-5, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, output_dim) + self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False) + + def forward(self, x, temb): + temb = self.linear(F.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + x = self.norm(x) * (1 + scale) + shift + return x + + +class AudioInjector_WAN(nn.Module): + + def __init__( + self, + all_modules, + all_modules_names, + dim=2048, + num_heads=32, + inject_layer=[0, 27], + enable_adain=False, + adain_dim=2048, + ): + super().__init__() + self.injected_block_id = {} + audio_injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if f'transformer_blocks.{inject_id}' in mod_name: + self.injected_block_id[inject_id] = audio_injector_id + audio_injector_id += 1 + + self.injector = nn.ModuleList([CrossAttention( + dim=dim, + num_heads=num_heads, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm( + dim, + elementwise_affine=False, + eps=1e-6, + ) for _ in range(audio_injector_id)]) + if enable_adain: + self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)]) + + +class CausalAudioEncoder(nn.Module): + + def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False): + super().__init__() + self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global) + weight = torch.ones((1, num_layers, 1, 1)) * 0.01 + + self.weights = torch.nn.Parameter(weight) + self.act = torch.nn.SiLU() + + def forward(self, features): + # features B * num_layers * dim * video_length + weights = self.act(self.weights.to(device=features.device, dtype=features.dtype)) + weights_sum = weights.sum(dim=1, keepdims=True) + weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f + weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim + res = self.encoder(weighted_feat) # b f n dim + return res # b f n dim + + +class WanS2VSelfAttention(SelfAttention): + + def forward(self, x, freqs): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x) + q = rope_apply(q, freqs) + k = rope_apply(k, freqs) + x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v) + return self.o(x) + + +class WanS2VDiTBlock(DiTBlock): + + def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False): + super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps) + self.self_attn = WanS2VSelfAttention(dim, num_heads, eps) + + def forward(self, x, context, e, freqs): + seg_idx = e[1].item() + seg_idx = min(max(0, seg_idx), x.size(1)) + seg_idx = [0, seg_idx, x.size(1)] + e = e[0] + modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device) + e = (modulation + e).chunk(6, dim=1) + e = [element.squeeze(1) for element in e] + norm_x = self.norm1(x) + parts = [] + for i in range(2): + parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1]) + norm_x = torch.cat(parts, dim=1) + # self-attention + y = self.self_attn(norm_x, freqs) + z = [] + for i in range(2): + z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1]) + y = torch.cat(z, dim=1) + x = x + y + + # cross-attention & ffn function + def cross_attn_ffn(x, context, e): + x = x + self.cross_attn(self.norm3(x), context) + norm2_x = self.norm2(x) + parts = [] + for i in range(2): + parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1]) + norm2_x = torch.cat(parts, dim=1) + y = self.ffn(norm2_x) + z = [] + for i in range(2): + z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1]) + y = torch.cat(z, dim=1) + x = x + y + return x + + x = cross_attn_ffn(x, context, e) + return x + + +class S2VHead(Head): + + def forward(self, x, t_mod): + t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0])) + return x + + +class WanS2VModel(torch.nn.Module): + + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + cond_dim: int, + audio_dim: int, + num_audio_token: int, + enable_adain: bool = True, + audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39], + zero_timestep: bool = True, + add_last_motion: bool = True, + framepack_drop_mode: str = "padd", + fuse_vae_embedding_in_latents: bool = True, + require_vae_embedding: bool = False, + seperated_timestep: bool = False, + require_clip_embedding: bool = False, + ): + super().__init__() + self.dim = dim + self.in_dim = in_dim + self.freq_dim = freq_dim + self.patch_size = patch_size + self.num_heads = num_heads + self.enbale_adain = enable_adain + self.add_last_motion = add_last_motion + self.zero_timestep = zero_timestep + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents + self.require_vae_embedding = require_vae_embedding + self.seperated_timestep = seperated_timestep + self.require_clip_embedding = require_clip_embedding + + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim)) + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = S2VHead(dim, out_dim, patch_size, eps) + self.freqs = precompute_freqs_cis_3d(dim // num_heads) + + self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) + self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) + all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks") + # TODO: refactor dfs + self.audio_injector = AudioInjector_WAN( + all_modules, + all_modules_names, + dim=dim, + num_heads=num_heads, + inject_layer=audio_inject_layers, + enable_adain=enable_adain, + adain_dim=dim, + ) + self.trainable_cond_mask = nn.Embedding(3, dim) + self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode) + + def patchify(self, x: torch.Tensor): + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, + 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], + h=grid_size[1], + w=grid_size[2], + x=self.patch_size[0], + y=self.patch_size[1], + z=self.patch_size[2] + ) + + def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2): + flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion) + if drop_motion_frames: + return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb] + else: + return flattern_mot, mot_remb + + def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): + # inject the motion frames token to the hidden states + mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) + if len(mot) > 0: + x = torch.cat([x, mot[0]], dim=1) + seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) + rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1) + mask_input = torch.cat( + [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1 + ) + return x, seq_lens, rope_embs, mask_input + + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len): + if block_idx in self.audio_injector.injected_block_id.keys(): + audio_attn_id = self.audio_injector.injected_block_id[block_idx] + num_frames = audio_emb.shape[1] + + input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + + audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c") + adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0]) + attn_hidden_states = adain_hidden_states + + audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out + + return hidden_states + + @staticmethod + def state_dict_converter(): + return WanS2VModelStateDictConverter() + + +class WanS2VModelStateDictConverter: + + def __init__(self): + pass + + def from_civitai(self, state_dict): + config = {} + if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614": + config = { + "dim": 5120, + "in_dim": 16, + "ffn_dim": 13824, + "out_dim": 16, + "text_dim": 4096, + "freq_dim": 256, + "eps": 1e-06, + "patch_size": (1, 2, 2), + "num_heads": 40, + "num_layers": 40, + "cond_dim": 16, + "audio_dim": 1024, + "num_audio_token": 4, + } + return state_dict, config diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py new file mode 100644 index 0000000..f07e99e --- /dev/null +++ b/diffsynth/models/wav2vec.py @@ -0,0 +1,199 @@ +import math +import numpy as np +import torch +import torch.nn.functional as F + + +def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None): + required_duration = num_sample / target_fps + required_origin_frames = int(np.ceil(required_duration * original_fps)) + if required_duration > total_frames / original_fps: + raise ValueError("required_duration must be less than video length") + + if not fixed_start is None and fixed_start >= 0: + start_frame = fixed_start + else: + max_start = total_frames - required_origin_frames + if max_start < 0: + raise ValueError("video length is too short") + start_frame = np.random.randint(0, max_start + 1) + start_time = start_frame / original_fps + + end_time = start_time + required_duration + time_points = np.linspace(start_time, end_time, num_sample, endpoint=False) + + frame_indices = np.round(np.array(time_points) * original_fps).astype(int) + frame_indices = np.clip(frame_indices, 0, total_frames - 1) + return frame_indices + + +def linear_interpolation(features, input_fps, output_fps, output_len=None): + """ + features: shape=[1, T, 512] + input_fps: fps for audio, f_a + output_fps: fps for video, f_m + output_len: video length + """ + features = features.transpose(1, 2) + seq_len = features.shape[2] / float(input_fps) + if output_len is None: + output_len = int(seq_len * output_fps) + output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len] + return output_features.transpose(1, 2) + + +class WanS2VAudioEncoder(torch.nn.Module): + + def __init__(self): + super().__init__() + from transformers import Wav2Vec2ForCTC, Wav2Vec2Config + config = { + "_name_or_path": "facebook/wav2vec2-large-xlsr-53", + "activation_dropout": 0.05, + "apply_spec_augment": True, + "architectures": ["Wav2Vec2ForCTC"], + "attention_dropout": 0.1, + "bos_token_id": 1, + "conv_bias": True, + "conv_dim": [512, 512, 512, 512, 512, 512, 512], + "conv_kernel": [10, 3, 3, 3, 3, 2, 2], + "conv_stride": [5, 2, 2, 2, 2, 2, 2], + "ctc_loss_reduction": "mean", + "ctc_zero_infinity": True, + "do_stable_layer_norm": True, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.05, + "final_dropout": 0.0, + "hidden_act": "gelu", + "hidden_dropout": 0.05, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.05, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.05, + "mask_time_selection": "static", + "model_type": "wav2vec2", + "num_attention_heads": 16, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "pad_token_id": 0, + "transformers_version": "4.7.0.dev0", + "vocab_size": 33 + } + self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config)) + self.video_rate = 30 + + def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32): + input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype) + + # retrieve logits & take argmax + res = self.model(input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + + z = feat.to(dtype) + return z + + def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1 + + bucket_num = min_batch_num * batch_frames + batch_idx = [stride * i for i in range(bucket_num)] + batch_audio_eb = [] + for bi in batch_idx: + if bi < audio_frame_num: + audio_sample_stride = 2 + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0): + num_layers, audio_frame_num, audio_dim = audio_embed.shape + + if num_layers > 1: + return_all_layers = True + else: + return_all_layers = False + + scale = self.video_rate / fps + + min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1 + + bucket_num = min_batch_num * batch_frames + padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num + batch_idx = get_sample_indices( + original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0 + ) + batch_audio_eb = [] + audio_sample_stride = int(self.video_rate / fps) + for bi in batch_idx: + if bi < audio_frame_num: + + chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride)) + chosen_idx = [0 if c < 0 else c for c in chosen_idx] + chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx] + + if return_all_layers: + frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1) + else: + frame_audio_embed = audio_embed[0][chosen_idx].flatten() + else: + frame_audio_embed = \ + torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ + else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) + batch_audio_eb.append(frame_audio_embed) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) + + return batch_audio_eb, min_batch_num + + @staticmethod + def state_dict_converter(): + return WanS2VAudioEncoderStateDictConverter() + + +class WanS2VAudioEncoderStateDictConverter(): + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict = {'model.' + k: v for k, v in state_dict.items()} + return state_dict diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 53df7d9..ff3c4bd 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -15,6 +15,7 @@ from typing_extensions import Literal from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner from ..models import ModelManager, load_state_dict from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d +from ..models.wan_video_dit_s2v import rope_precompute from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample from ..models.wan_video_image_encoder import WanImageEncoder @@ -49,6 +50,7 @@ class WanVideoPipeline(BasePipeline): self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), + WanVideoUnit_S2V(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedderVAE(), @@ -127,6 +129,8 @@ class WanVideoPipeline(BasePipeline): torch.nn.LayerNorm: WanAutoCastLayerNorm, RMSNorm: AutoWrappedModule, torch.nn.Conv2d: AutoWrappedModule, + torch.nn.Conv1d: AutoWrappedModule, + torch.nn.Embedding: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, @@ -254,6 +258,24 @@ class WanVideoPipeline(BasePipeline): ), vram_limit=vram_limit, ) + if self.audio_encoder is not None: + # TODO: need check + dtype = next(iter(self.audio_encoder.parameters())).dtype + enable_vram_management( + self.audio_encoder, + module_map = { + torch.nn.Linear: AutoWrappedLinear, + torch.nn.LayerNorm: AutoWrappedModule, + }, + module_config = dict( + offload_dtype=dtype, + offload_device="cpu", + onload_dtype=dtype, + onload_device="cpu", + computation_dtype=dtype, + computation_device=self.device, + ), + ) def initialize_usp(self): @@ -290,6 +312,7 @@ class WanVideoPipeline(BasePipeline): device: Union[str, torch.device] = "cuda", model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"), + audio_processor_config: ModelConfig = None, redirect_common_files: bool = True, use_usp=False, ): @@ -332,7 +355,8 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") - + pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder") + # Size division factor if pipe.vae is not None: pipe.height_division_factor = pipe.vae.upsampling_factor * 2 @@ -342,7 +366,11 @@ class WanVideoPipeline(BasePipeline): tokenizer_config.download_if_necessary(use_usp=use_usp) pipe.prompter.fetch_models(pipe.text_encoder) pipe.prompter.fetch_tokenizer(tokenizer_config.path) - + + if audio_processor_config is not None: + audio_processor_config.download_if_necessary(use_usp=use_usp) + from transformers import Wav2Vec2Processor + pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path) # Unified Sequence Parallel if use_usp: pipe.enable_usp() return pipe @@ -361,6 +389,10 @@ class WanVideoPipeline(BasePipeline): # Video-to-video input_video: Optional[list[Image.Image]] = None, denoising_strength: Optional[float] = 1.0, + # Speech-to-video + input_audio: Optional[str] = None, + audio_sample_rate: Optional[int] = 16000, + s2v_pose_video: Optional[list[Image.Image]] = None, # ControlNet control_video: Optional[list[Image.Image]] = None, reference_image: Optional[Image.Image] = None, @@ -429,6 +461,7 @@ class WanVideoPipeline(BasePipeline): "motion_bucket_id": motion_bucket_id, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, + "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -868,6 +901,67 @@ class WanVideoUnit_CfgMerger(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega +class WanVideoUnit_S2V(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("audio_encoder", "vae", ) + ) + + def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames): + if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None: + return {} + pipe.load_models_to_device(["audio_encoder"]) + z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True) + audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps( + z, fps=16, batch_frames=num_frames - 1, m=0 + ) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1] + return {"audio_input": audio_embed_bucket} + + def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride): + pipe.load_models_to_device(["vae"]) + # TODO: may support input motion latents + motion_frames = 73 + motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) + lat_motion_frames = (motion_frames + 3) // 4 + motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]} + + def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride): + pipe.load_models_to_device(["vae"]) + if s2v_pose_video is None: + input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype) + else: + input_video = pipe.preprocess_video(s2v_pose_video) + # get num_frames-1 frames + input_video = input_video[:, :, :num_frames] + # pad if not enough frames + padding_frames = num_frames - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + # encode to latents + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_cond": input_latents[:,:,1:]} + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None: + return inputs_shared, inputs_posi, inputs_nega + input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width") + tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride") + + audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames) + inputs_posi.update(audio_input_positive) + inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]}) + + inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride)) + inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride)) + return inputs_shared, inputs_posi, inputs_nega + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): @@ -987,6 +1081,10 @@ def model_fn_wan_video( reference_latents = None, vace_context = None, vace_scale = 1.0, + audio_input: Optional[torch.Tensor] = None, + motion_latents: Optional[torch.Tensor] = None, + motion_frames: Optional[list] = None, + pose_cond: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, motion_bucket_id: Optional[torch.Tensor] = None, @@ -1024,7 +1122,21 @@ def model_fn_wan_video( tensor_names=["latents", "y"], batch_size=2 if cfg_merge else 1 ) - + # wan2.2 s2v + if audio_input is not None: + return model_fn_wans2v( + dit=dit, + latents=latents, + timestep=timestep, + context=context, + audio_input=audio_input, + motion_latents=motion_latents, + motion_frames=motion_frames, + pose_cond=pose_cond, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + use_gradient_checkpointing=use_gradient_checkpointing, + ) + if use_unified_sequence_parallel: import torch.distributed as dist from xfuser.core.distributed import (get_sequence_parallel_rank, @@ -1143,3 +1255,107 @@ def model_fn_wan_video( f -= 1 x = dit.unpatchify(x, (f, h, w)) return x + + +def model_fn_wans2v( + dit, + latents, + timestep, + context, + audio_input, + motion_latents, + motion_frames, + pose_cond, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False +): + origin_ref_latents = latents[:, :, 0:1] + latents = latents[:, :, 1:] + + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) + audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input) + audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() + merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + + # reference image + x = latents + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) + + grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) + seq_lens = torch.tensor([x.size(1)], dtype=torch.long) + grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] + + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + + ref_grid_sizes = [[ + torch.tensor([30, 0, 0]).unsqueeze(0), + torch.tensor([31, rh, rw]).unsqueeze(0), + torch.tensor([1, rh, rw]).unsqueeze(0), + ]] + original_seq_len = seq_lens[0] + seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long) + grid_sizes = grid_sizes + ref_grid_sizes + + x = torch.cat([x, ref_latents], dim=1) + mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device) + mask[:, -ref_latents.shape[1]:] = 1 + + b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads + pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None) + + x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + x = x + dit.trainable_cond_mask(mask).to(x.dtype) + + # t_mod + if dit.zero_timestep: + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + e0 = dit.time_projection(e).unflatten(1, (6, dit.dim)) + if dit.zero_timestep: + e = e[:-1] + zero_e0 = e0[-1:] + e0 = e0[:-1] + e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2) + e0 = [e0, original_seq_len] + # context + context = dit.text_embedding(context) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(dit.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, e0, pre_compute_freqs, + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, e0, pre_compute_freqs, + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, e0, pre_compute_freqs) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len) + + x = x[:, :original_seq_len] + x = dit.head(x, e) + x = dit.unpatchify(x, (f, h, w)) + x = torch.cat([origin_ref_latents, x], dim=2) + return x diff --git a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py new file mode 100644 index 0000000..73d4a49 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py @@ -0,0 +1,63 @@ +import torch +from PIL import Image +import librosa +from diffsynth import save_video, VideoData, save_video_with_audio +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), +) +num_frames = 81 # 4n+1 +height = 448 +width = 832 + +prompt = "a person is singing" +input_image = Image.open("/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.png").convert("RGB").resize((width, height)) +# s2v audio input, recommend 16kHz sampling rate +audio_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/sing.MP3' +input_audio, sample_rate = librosa.load(audio_path, sr=16000) + +# Speech-to-video +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt="", + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + num_inference_steps=40, +) +save_video_with_audio(video, "video_with_audio.mp4", audio_path, fps=16, quality=5) + +# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. +pose_video_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.mp4' +pose_video = VideoData(pose_video_path, height=height, width=width) +pose_video.set_length(num_frames) + +# Speech-to-video with pose +video = pipe( + prompt=prompt, + input_image=input_image, + negative_prompt="", + seed=0, + num_frames=num_frames, + height=height, + width=width, + audio_sample_rate=sample_rate, + input_audio=input_audio, + s2v_pose_video=pose_video, + num_inference_steps=40, +) +save_video_with_audio(video, "video_pose_with_audio.mp4", audio_path, fps=16, quality=5) +save_video(pose_video, "video_pose_input.mp4", fps=16, quality=5) From 8a0bd7c377e48756abd6b63301512d67e6cd4015 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 13:05:53 +0800 Subject: [PATCH 2/6] wans2v lowvram --- diffsynth/models/wav2vec.py | 10 ++++------ diffsynth/pipelines/wan_video_new.py | 9 +++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py index f07e99e..e17bd96 100644 --- a/diffsynth/models/wav2vec.py +++ b/diffsynth/models/wav2vec.py @@ -99,19 +99,17 @@ class WanS2VAudioEncoder(torch.nn.Module): self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config)) self.video_rate = 30 - def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32): - input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype) + def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'): + input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device) # retrieve logits & take argmax - res = self.model(input_values.to(self.model.device), output_hidden_states=True) + res = self.model(input_values, output_hidden_states=True) if return_all_layers: feat = torch.cat(res.hidden_states) else: feat = res.hidden_states[-1] feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) - - z = feat.to(dtype) - return z + return feat def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2): num_layers, audio_frame_num, audio_dim = audio_embed.shape diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index ff3c4bd..16df7f4 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -50,9 +50,9 @@ class WanVideoPipeline(BasePipeline): self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), + WanVideoUnit_PromptEmbedder(), WanVideoUnit_S2V(), WanVideoUnit_InputVideoEmbedder(), - WanVideoUnit_PromptEmbedder(), WanVideoUnit_ImageEmbedderVAE(), WanVideoUnit_ImageEmbedderCLIP(), WanVideoUnit_ImageEmbedderFused(), @@ -266,13 +266,14 @@ class WanVideoPipeline(BasePipeline): module_map = { torch.nn.Linear: AutoWrappedLinear, torch.nn.LayerNorm: AutoWrappedModule, + torch.nn.Conv1d: AutoWrappedModule, }, module_config = dict( offload_dtype=dtype, offload_device="cpu", onload_dtype=dtype, onload_device="cpu", - computation_dtype=dtype, + computation_dtype=self.torch_dtype, computation_device=self.device, ), ) @@ -905,14 +906,14 @@ class WanVideoUnit_S2V(PipelineUnit): def __init__(self): super().__init__( take_over=True, - onload_model_names=("audio_encoder", "vae", ) + onload_model_names=("audio_encoder", "vae",) ) def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames): if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None: return {} pipe.load_models_to_device(["audio_encoder"]) - z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True) + z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device) audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps( z, fps=16, batch_frames=num_frames - 1, m=0 ) From 4147473c811b076cbbf32886eae270eab91834cd Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 16:18:22 +0800 Subject: [PATCH 3/6] wans2v refactor --- diffsynth/models/wan_video_dit_s2v.py | 221 +++++++++++++++----------- diffsynth/pipelines/wan_video_new.py | 118 ++++++-------- 2 files changed, 183 insertions(+), 156 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 3121c98..b0016df 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from typing import Tuple from .utils import hash_state_dict_keys -from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention +from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d def torch_dfs(model: nn.Module, parent_name='root'): @@ -24,22 +24,6 @@ def torch_dfs(model: nn.Module, parent_name='root'): return modules, module_names -def rope_apply(x, freqs): - n, c = x.size(2), x.size(3) // 2 - # loop over samples - output = [] - for i, _ in enumerate(x): - s = x.size(1) - x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) - freqs_i = freqs[i, :s] - # apply rotary embedding - x_i = torch.view_as_real(x_i * freqs_i).flatten(2) - x_i = torch.cat([x_i, x[i, s:]]) - # append to collection - output.append(x_i) - return torch.stack(output).to(x.dtype) - - def rope_precompute(x, grid_sizes, freqs, start=None): b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2 @@ -135,11 +119,8 @@ class MotionEncoder_tc(nn.Module): self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs) self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs) - self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) def forward(self, x): @@ -358,71 +339,21 @@ class CausalAudioEncoder(nn.Module): return res # b f n dim -class WanS2VSelfAttention(SelfAttention): - - def forward(self, x, freqs): - b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x) - q = rope_apply(q, freqs) - k = rope_apply(k, freqs) - x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v) - return self.o(x) - - class WanS2VDiTBlock(DiTBlock): - def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False): - super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps) - self.self_attn = WanS2VSelfAttention(dim, num_heads, eps) - - def forward(self, x, context, e, freqs): - seg_idx = e[1].item() - seg_idx = min(max(0, seg_idx), x.size(1)) - seg_idx = [0, seg_idx, x.size(1)] - e = e[0] - modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device) - e = (modulation + e).chunk(6, dim=1) - e = [element.squeeze(1) for element in e] - norm_x = self.norm1(x) - parts = [] - for i in range(2): - parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1]) - norm_x = torch.cat(parts, dim=1) - # self-attention - y = self.self_attn(norm_x, freqs) - z = [] - for i in range(2): - z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1]) - y = torch.cat(z, dim=1) - x = x + y - - # cross-attention & ffn function - def cross_attn_ffn(x, context, e): - x = x + self.cross_attn(self.norm3(x), context) - norm2_x = self.norm2(x) - parts = [] - for i in range(2): - parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1]) - norm2_x = torch.cat(parts, dim=1) - y = self.ffn(norm2_x) - z = [] - for i in range(2): - z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1]) - y = torch.cat(z, dim=1) - x = x + y - return x - - x = cross_attn_ffn(x, context, e) - return x - - -class S2VHead(Head): - - def forward(self, x, t_mod): - t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0])) + def forward(self, x, context, t_mod, seq_len_x, freqs): + t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1) + # t_mod[:, :, 0] for x, t_mod[:, :, 1] for other like ref, motion, etc. + t_mod = [ + torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1) + for element in t_mod + ] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod + input_x = modulate(self.norm1(x), shift_msa, scale_msa) + x = self.gate(x, gate_msa, self.self_attn(input_x, freqs)) + x = x + self.cross_attn(self.norm3(x), context) + input_x = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = self.gate(x, gate_mlp, self.ffn(input_x)) return x @@ -472,9 +403,9 @@ class WanS2VModel(torch.nn.Module): self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) - self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) - self.head = S2VHead(dim, out_dim, patch_size, eps) - self.freqs = precompute_freqs_cis_3d(dim // num_heads) + self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)]) + self.head = Head(dim, out_dim, patch_size, eps) + self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1) self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size) self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain) @@ -516,17 +447,17 @@ class WanS2VModel(torch.nn.Module): else: return flattern_mot, mot_remb - def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): + def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2): # inject the motion frames token to the hidden states + # TODO: check drop_motion_frames = False mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion) if len(mot) > 0: x = torch.cat([x, mot[0]], dim=1) - seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long) rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1) mask_input = torch.cat( [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1 ) - return x, seq_lens, rope_embs, mask_input + return x, rope_embs, mask_input def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len): if block_idx in self.audio_injector.injected_block_id.keys(): @@ -548,6 +479,118 @@ class WanS2VModel(torch.nn.Module): return hidden_states + def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): + audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) + audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input) + audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() + merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + return audio_emb_global, merged_audio_emb + + def get_grid_sizes(self, grid_size_x, grid_size_ref): + f, h, w = grid_size_x + rf, rh, rw = grid_size_ref + grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) + grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]] + grid_sizes_ref = [[ + torch.tensor([30, 0, 0]).unsqueeze(0), + torch.tensor([31, rh, rw]).unsqueeze(0), + torch.tensor([1, rh, rw]).unsqueeze(0), + ]] + return grid_sizes_x + grid_sizes_ref + + def forward( + self, + latents, + timestep, + context, + audio_input, + motion_latents, + pose_cond, + use_gradient_checkpointing_offload=False, + use_gradient_checkpointing=False + ): + origin_ref_latents = latents[:, :, 0:1] + x = latents[:, :, 1:] + + # context embedding + context = self.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input) + + # x and pose_cond + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = x.shape[1] + + # reference image + ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw)) + x = torch.cat([x, ref_latents], dim=1) + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute( + x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None + ) + # motion + x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) + + x = x + self.trainable_cond_mask(mask).to(x.dtype) + + # t_mod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + for block_id, block in enumerate(self.blocks): + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs, + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + elif use_gradient_checkpointing: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, + context, + t_mod, + seq_len_x, + pre_compute_freqs, + use_reentrant=False, + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), + x, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs) + x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + + x = x[:, :seq_len_x] + x = self.head(x, t[:-1]) + x = self.unpatchify(x, (f, h, w)) + # make compatible with wan video + x = torch.cat([origin_ref_latents, x], dim=2) + return x + @staticmethod def state_dict_converter(): return WanS2VModelStateDictConverter() diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 16df7f4..1362d09 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -927,24 +927,23 @@ class WanVideoUnit_S2V(PipelineUnit): def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride): pipe.load_models_to_device(["vae"]) - # TODO: may support input motion latents + # TODO: may support input motion latents, which related to `drop_motion_frames = False` motion_frames = 73 + lat_motion_frames = (motion_frames + 3) // 4 # 19 motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device) - lat_motion_frames = (motion_frames + 3) // 4 motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]} + return {"motion_latents": motion_latents} def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride): - pipe.load_models_to_device(["vae"]) if s2v_pose_video is None: - input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype) - else: - input_video = pipe.preprocess_video(s2v_pose_video) - # get num_frames-1 frames - input_video = input_video[:, :, :num_frames] - # pad if not enough frames - padding_frames = num_frames - input_video.shape[2] - input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) + return {"pose_cond": None} + pipe.load_models_to_device(["vae"]) + input_video = pipe.preprocess_video(s2v_pose_video) + # get num_frames-1 frames + input_video = input_video[:, :, :num_frames] + # pad if not enough frames + padding_frames = num_frames - input_video.shape[2] + input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2) # encode to latents input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) return {"pose_cond": input_latents[:,:,1:]} @@ -1084,7 +1083,6 @@ def model_fn_wan_video( vace_scale = 1.0, audio_input: Optional[torch.Tensor] = None, motion_latents: Optional[torch.Tensor] = None, - motion_frames: Optional[list] = None, pose_cond: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, @@ -1132,10 +1130,10 @@ def model_fn_wan_video( context=context, audio_input=audio_input, motion_latents=motion_latents, - motion_frames=motion_frames, pose_cond=pose_cond, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, use_gradient_checkpointing=use_gradient_checkpointing, + use_unified_sequence_parallel=use_unified_sequence_parallel, ) if use_unified_sequence_parallel: @@ -1265,62 +1263,47 @@ def model_fn_wans2v( context, audio_input, motion_latents, - motion_frames, pose_cond, use_gradient_checkpointing_offload=False, - use_gradient_checkpointing=False + use_gradient_checkpointing=False, + use_unified_sequence_parallel=False, ): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) origin_ref_latents = latents[:, :, 0:1] - latents = latents[:, :, 1:] + x = latents[:, :, 1:] - audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1) - audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input) - audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone() - merged_audio_emb = audio_emb[:, motion_frames[1]:, :] + # context embedding + context = dit.text_embedding(context) + + # audio encode + audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input) + + # x and pose_cond + pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) + seq_len_x = x.shape[1] # reference image - x = latents - pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) - - grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0) - seq_lens = torch.tensor([x.size(1)], dtype=torch.long) - grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]] - - ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) - - ref_grid_sizes = [[ - torch.tensor([30, 0, 0]).unsqueeze(0), - torch.tensor([31, rh, rw]).unsqueeze(0), - torch.tensor([1, rh, rw]).unsqueeze(0), - ]] - original_seq_len = seq_lens[0] - seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long) - grid_sizes = grid_sizes + ref_grid_sizes - + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) x = torch.cat([x, ref_latents], dim=1) - mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device) - mask[:, -ref_latents.shape[1]:] = 1 + # mask + mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device) + # freqs + pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None) + # motion + x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2) - b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads - pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None) - - x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2) x = x + dit.trainable_cond_mask(mask).to(x.dtype) - # t_mod - if dit.zero_timestep: - timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) - e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) - e0 = dit.time_projection(e).unflatten(1, (6, dit.dim)) - if dit.zero_timestep: - e = e[:-1] - zero_e0 = e0[-1:] - e0 = e0[:-1] - e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2) - e0 = [e0, original_seq_len] - # context - context = dit.text_embedding(context) + # tmod + timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)]) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) + t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) def create_custom_forward(module): def custom_forward(*inputs): @@ -1332,31 +1315,32 @@ def model_fn_wans2v( with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, e0, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs, use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)), + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), x, use_reentrant=False, ) elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, e0, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs, use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( - create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)), + create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)), x, use_reentrant=False, ) else: - x = block(x, context, e0, pre_compute_freqs) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len) + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) - x = x[:, :original_seq_len] - x = dit.head(x, e) + x = x[:, :seq_len_x] + x = dit.head(x, t[:-1]) x = dit.unpatchify(x, (f, h, w)) + # make compatible with wan video x = torch.cat([origin_ref_latents, x], dim=2) return x From fdeb363fa2a9f1acda13b47df54da30757fd5d05 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 19:50:33 +0800 Subject: [PATCH 4/6] wans2v usp --- diffsynth/models/wan_video_dit_s2v.py | 9 +++++-- diffsynth/pipelines/wan_video_new.py | 27 +++++++++++++------ .../model_inference/Wan2.1-S2V-14B.py | 26 +++++++++++------- 3 files changed, 42 insertions(+), 20 deletions(-) diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index b0016df..75b19a4 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -459,10 +459,13 @@ class WanS2VModel(torch.nn.Module): ) return x, rope_embs, mask_input - def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len): + def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False): if block_idx in self.audio_injector.injected_block_id.keys(): audio_attn_id = self.audio_injector.injected_block_id[block_idx] num_frames = audio_emb.shape[1] + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sp_group + hidden_states = get_sp_group().all_gather(hidden_states, dim=1) input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) @@ -476,7 +479,9 @@ class WanS2VModel(torch.nn.Module): residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out - + if use_unified_sequence_parallel: + from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] return hidden_states def cal_audio_emb(self, audio_input, motion_frames=[73, 19]): diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 1362d09..cef7ae8 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -1284,11 +1284,11 @@ def model_fn_wans2v( # x and pose_cond pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120]) - seq_len_x = x.shape[1] + x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) + seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel # reference image - ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120]) + ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) x = torch.cat([x, ref_latents], dim=1) # mask @@ -1305,6 +1305,14 @@ def model_fn_wans2v( t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2) + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank() + assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}" + x = torch.chunk(x, world_size, dim=1)[sp_rank] + seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy()) + seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)] + seq_len_x = seq_len_x_list[sp_rank] + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) @@ -1315,7 +1323,7 @@ def model_fn_wans2v( with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -1326,7 +1334,7 @@ def model_fn_wans2v( elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - x, context, t_mod, seq_len_x, pre_compute_freqs, + x, context, t_mod, seq_len_x, pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -1335,10 +1343,13 @@ def model_fn_wans2v( use_reentrant=False, ) else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs) - x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) + x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel) - x = x[:, :seq_len_x] + if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) + + x = x[:, :seq_len_x_global] x = dit.head(x, t[:-1]) x = dit.unpatchify(x, (f, h, w)) # make compatible with wan video diff --git a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py index 73d4a49..bb93871 100644 --- a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py +++ b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py @@ -1,8 +1,9 @@ import torch from PIL import Image import librosa -from diffsynth import save_video, VideoData, save_video_with_audio +from diffsynth import VideoData, save_video_with_audio from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download pipe = WanVideoPipeline.from_pretrained( torch_dtype=torch.bfloat16, @@ -15,21 +16,28 @@ pipe = WanVideoPipeline.from_pretrained( ], audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), ) +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/example_video_dataset", + local_dir="./data/example_video_dataset", + allow_file_pattern=f"wans2v/*" +) + num_frames = 81 # 4n+1 height = 448 width = 832 prompt = "a person is singing" -input_image = Image.open("/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.png").convert("RGB").resize((width, height)) +negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) # s2v audio input, recommend 16kHz sampling rate -audio_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/sing.MP3' +audio_path = 'data/example_video_dataset/wans2v/sing.MP3' input_audio, sample_rate = librosa.load(audio_path, sr=16000) # Speech-to-video video = pipe( prompt=prompt, input_image=input_image, - negative_prompt="", + negative_prompt=negative_prompt, seed=0, num_frames=num_frames, height=height, @@ -38,18 +46,17 @@ video = pipe( input_audio=input_audio, num_inference_steps=40, ) -save_video_with_audio(video, "video_with_audio.mp4", audio_path, fps=16, quality=5) +save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5) # s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. -pose_video_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.mp4' +pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4' pose_video = VideoData(pose_video_path, height=height, width=width) -pose_video.set_length(num_frames) # Speech-to-video with pose video = pipe( prompt=prompt, input_image=input_image, - negative_prompt="", + negative_prompt=negative_prompt, seed=0, num_frames=num_frames, height=height, @@ -59,5 +66,4 @@ video = pipe( s2v_pose_video=pose_video, num_inference_steps=40, ) -save_video_with_audio(video, "video_pose_with_audio.mp4", audio_path, fps=16, quality=5) -save_video(pose_video, "video_pose_input.mp4", fps=16, quality=5) +save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5) From caa17da5b97d151a133b66c5a2fb08209337fdfb Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 27 Aug 2025 20:05:44 +0800 Subject: [PATCH 5/6] wans2v readme --- README.md | 3 +++ README_zh.md | 3 +++ examples/wanvideo/README.md | 1 + examples/wanvideo/README_zh.md | 1 + 4 files changed, 8 insertions(+) diff --git a/README.md b/README.md index b1c8716..127467d 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training | |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -372,6 +373,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History +- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/). + - **August 21, 2025**: [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) is released! Compared to the V1 version, the training dataset has been updated to the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), enabling generated images to better align with the inherent image distribution and style of Qwen-Image. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py). - **August 21, 2025**: We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structure control LoRA model. Following "In Context" routine, it supports various types of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py). diff --git a/README_zh.md b/README_zh.md index a8ff395..ba8197f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -201,6 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| @@ -388,6 +389,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 +- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。 + - **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。 - **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。 diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 4e5195a..456d957 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index bcc076f..1ac53ca 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| From 9cea10cc69a9035b8c13a4167b61cd783a6f7e17 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 28 Aug 2025 10:13:52 +0800 Subject: [PATCH 6/6] minor fix --- README.md | 2 +- README_zh.md | 2 +- diffsynth/models/wan_video_dit_s2v.py | 6 +- examples/wanvideo/README.md | 2 +- examples/wanvideo/README_zh.md | 2 +- .../model_inference/Wan2.1-S2V-14B.py | 69 ------------------- 6 files changed, 7 insertions(+), 76 deletions(-) delete mode 100644 examples/wanvideo/model_inference/Wan2.1-S2V-14B.py diff --git a/README.md b/README.md index 127467d..ce3fea8 100644 --- a/README.md +++ b/README.md @@ -201,7 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training | |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/README_zh.md b/README_zh.md index ba8197f..08f62e9 100644 --- a/README_zh.md +++ b/README_zh.md @@ -201,7 +201,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py index 75b19a4..fa54591 100644 --- a/diffsynth/models/wan_video_dit_s2v.py +++ b/diffsynth/models/wan_video_dit_s2v.py @@ -562,7 +562,7 @@ class WanS2VModel(torch.nn.Module): context, t_mod, seq_len_x, - pre_compute_freqs, + pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -577,7 +577,7 @@ class WanS2VModel(torch.nn.Module): context, t_mod, seq_len_x, - pre_compute_freqs, + pre_compute_freqs[0], use_reentrant=False, ) x = torch.utils.checkpoint.checkpoint( @@ -586,7 +586,7 @@ class WanS2VModel(torch.nn.Module): use_reentrant=False, ) else: - x = block(x, context, t_mod, seq_len_x, pre_compute_freqs) + x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0]) x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x) x = x[:, :seq_len_x] diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 456d957..add9fa5 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 1ac53ca..57a36c7 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -48,7 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| -|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.1-S2V-14B.py)|-|-|-|-| +|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| diff --git a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py b/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py deleted file mode 100644 index bb93871..0000000 --- a/examples/wanvideo/model_inference/Wan2.1-S2V-14B.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -from PIL import Image -import librosa -from diffsynth import VideoData, save_video_with_audio -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import dataset_snapshot_download - -pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cuda", - model_configs=[ - ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), - ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"), - ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), - ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), - ], - audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"), -) -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/example_video_dataset", - local_dir="./data/example_video_dataset", - allow_file_pattern=f"wans2v/*" -) - -num_frames = 81 # 4n+1 -height = 448 -width = 832 - -prompt = "a person is singing" -negative_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -input_image = Image.open("data/example_video_dataset/wans2v/pose.png").convert("RGB").resize((width, height)) -# s2v audio input, recommend 16kHz sampling rate -audio_path = 'data/example_video_dataset/wans2v/sing.MP3' -input_audio, sample_rate = librosa.load(audio_path, sr=16000) - -# Speech-to-video -video = pipe( - prompt=prompt, - input_image=input_image, - negative_prompt=negative_prompt, - seed=0, - num_frames=num_frames, - height=height, - width=width, - audio_sample_rate=sample_rate, - input_audio=input_audio, - num_inference_steps=40, -) -save_video_with_audio(video[1:], "video_with_audio.mp4", audio_path, fps=16, quality=5) - -# s2v will use the first (num_frames) frames as reference. height and width must be the same as input_image. And fps should be 16, the same as output video fps. -pose_video_path = 'data/example_video_dataset/wans2v/pose.mp4' -pose_video = VideoData(pose_video_path, height=height, width=width) - -# Speech-to-video with pose -video = pipe( - prompt=prompt, - input_image=input_image, - negative_prompt=negative_prompt, - seed=0, - num_frames=num_frames, - height=height, - width=width, - audio_sample_rate=sample_rate, - input_audio=input_audio, - s2v_pose_video=pose_video, - num_inference_steps=40, -) -save_video_with_audio(video[1:], "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)