wans2v inference

mi804
2025-08-27 11:51:56 +08:00
parent 04e39f7de5
commit b541b9bed2
7 changed files with 1134 additions and 4 deletions

View File

@@ -56,11 +56,13 @@ from ..models.stepvideo_vae import StepVideoVAE
from ..models.stepvideo_dit import StepVideoModel
from ..models.wan_video_dit import WanModel
from ..models.wan_video_dit_s2v import WanS2VModel
from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.step1x_connector import Qwen2Connector
@@ -155,6 +157,7 @@ model_loader_configs = [
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "966cffdcc52f9c46c391768b27637614", ["wan_video_dit"], [WanS2VModel], "civitai"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"), (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
@@ -172,6 +175,7 @@ model_loader_configs = [
(None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"), (None, "ed4ea5824d55ec3107b09815e318123a", ["qwen_image_vae"], [QwenImageVAE], "diffusers"),
(None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
(None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
(None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
] ]
huggingface_model_loader_configs = [ huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.

View File

@@ -1 +1 @@
from .video import VideoData, save_video, save_frames
from .video import VideoData, save_video, save_frames, merge_video_audio, save_video_with_audio

View File

@@ -2,6 +2,8 @@ import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm
import subprocess
import shutil
class LowMemoryVideo:
@@ -146,3 +148,70 @@ def save_frames(frames, save_path):
os.makedirs(save_path, exist_ok=True)
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
frame.save(os.path.join(save_path, f"{i}.png"))
def merge_video_audio(video_path: str, audio_path: str):
# TODO: may need an in-Python implementation to avoid the subprocess dependency
"""
Merge the video and audio into a new video, with the duration set to the shorter of the two,
and overwrite the original video file.
Parameters:
video_path (str): Path to the original video file
audio_path (str): Path to the audio file
"""
# check
if not os.path.exists(video_path):
raise FileNotFoundError(f"video file {video_path} does not exist")
if not os.path.exists(audio_path):
raise FileNotFoundError(f"audio file {audio_path} does not exist")
base, ext = os.path.splitext(video_path)
temp_output = f"{base}_temp{ext}"
try:
# create ffmpeg command
command = [
'ffmpeg',
'-y', # overwrite
'-i',
video_path,
'-i',
audio_path,
'-c:v',
'copy', # copy video stream
'-c:a',
'aac', # use AAC audio encoder
'-b:a',
'192k', # set audio bitrate (optional)
'-map',
'0:v:0', # select the first video stream
'-map',
'1:a:0', # select the first audio stream
'-shortest', # choose the shortest duration
temp_output
]
# execute the command
result = subprocess.run(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# check result
if result.returncode != 0:
error_msg = f"FFmpeg execute failed: {result.stderr}"
print(error_msg)
raise RuntimeError(error_msg)
shutil.move(temp_output, video_path)
print(f"Merge completed, saved to {video_path}")
except Exception as e:
if os.path.exists(temp_output):
os.remove(temp_output)
print(f"merge_video_audio failed with error: {e}")
def save_video_with_audio(frames, save_path, audio_path, fps=16, quality=9, ffmpeg_params=None):
save_video(frames, save_path, fps, quality, ffmpeg_params)
merge_video_audio(save_path, audio_path)
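A minimal usage sketch of the new helper (the frame list and file names below are placeholders, and merge_video_audio shells out to ffmpeg, so an ffmpeg binary must be available on the PATH):

from PIL import Image
from diffsynth import save_video_with_audio

frames = [Image.new("RGB", (640, 480), "black") for _ in range(32)]  # dummy frames
save_video_with_audio(frames, "output.mp4", "speech.wav", fps=16, quality=9)
# output.mp4 is written first, then speech.wav is muxed in and the result is trimmed to the shorter stream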

View File

@@ -0,0 +1,579 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .utils import hash_state_dict_keys
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
def torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
def rope_apply(x, freqs):
n, c = x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = freqs[i, :s]
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).to(x.dtype)
def rope_precompute(x, grid_sizes, freqs, start=None):
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
# split freqs
if type(freqs) is list:
trainable_freqs = freqs[1]
freqs = freqs[0]
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
seq_bucket = [0]
if not type(grid_sizes) is list:
grid_sizes = [grid_sizes]
for g in grid_sizes:
if not type(g) is list:
g = [torch.zeros_like(g), g]
batch_size = g[0].shape[0]
for i in range(batch_size):
if start is None:
f_o, h_o, w_o = g[0][i]
else:
f_o, h_o, w_o = start[i]
f, h, w = g[1][i]
t_f, t_h, t_w = g[2][i]
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
seq_len = int(seq_f * seq_h * seq_w)
if seq_len > 0:
if t_f > 0:
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
# Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
if f_o >= 0:
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
else:
f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
freqs_i = torch.cat(
[
freqs_0.expand(seq_f, seq_h, seq_w, -1),
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
],
dim=-1
).reshape(seq_len, 1, -1)
elif t_f < 0:
freqs_i = trainable_freqs.unsqueeze(1)
# apply rotary embedding
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
seq_bucket.append(seq_bucket[-1] + seq_len)
return output
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class MotionEncoder_tc(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, need_global=True, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
if need_global:
self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
if need_global:
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local
class FramePackMotioner(nn.Module):
def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
super().__init__(*args, **kwargs)
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)
self.inner_dim = inner_dim
self.num_heads = num_heads
self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
self.drop_mode = drop_mode
def forward(self, motion_latents, add_last_motion=2):
motion_frames = motion_latents[0].shape[1]
mot = []
mot_remb = []
for m in motion_latents:
lat_height, lat_width = m.shape[2], m.shape[3]
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
overlap_frame = min(padd_lat.shape[1], m.shape[1])
if overlap_frame > 0:
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
if add_last_motion < 2 and self.drop_mode != "drop":
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
padd_lat[:, -zero_end_frame:] = 0
padd_lat = padd_lat.unsqueeze(0)
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
list(self.zip_frame_buckets)[::-1], dim=2
) # buckets reversed: 16, 2, 1
# patchify
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)
if add_last_motion < 2 and self.drop_mode == "drop":
clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
# rope
start_time_id = -(self.zip_frame_buckets[:1].sum())
end_time_id = start_time_id + self.zip_frame_buckets[0]
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:2].sum())
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:3].sum())
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
grid_sizes_4x = [
[
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
]
]
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
motion_rope_emb = rope_precompute(
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
grid_sizes,
self.freqs,
start=None
)
mot.append(motion_lat)
mot_remb.append(motion_rope_emb)
return mot, mot_remb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
output_dim: int,
norm_eps: float = 1e-5,
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)
def forward(self, x, temb):
temb = self.linear(F.silu(temb))
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
x = self.norm(x) * (1 + scale) + shift
return x
class AudioInjector_WAN(nn.Module):
def __init__(
self,
all_modules,
all_modules_names,
dim=2048,
num_heads=32,
inject_layer=[0, 27],
enable_adain=False,
adain_dim=2048,
):
super().__init__()
self.injected_block_id = {}
audio_injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, DiTBlock):
for inject_id in inject_layer:
if f'transformer_blocks.{inject_id}' in mod_name:
self.injected_block_id[inject_id] = audio_injector_id
audio_injector_id += 1
self.injector = nn.ModuleList([CrossAttention(
dim=dim,
num_heads=num_heads,
) for _ in range(audio_injector_id)])
self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)])
self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)])
if enable_adain:
self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])
class CausalAudioEncoder(nn.Module):
def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
super().__init__()
self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
self.weights = torch.nn.Parameter(weight)
self.act = torch.nn.SiLU()
def forward(self, features):
# features B * num_layers * dim * video_length
weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
res = self.encoder(weighted_feat) # b f n dim
return res # b f n dim
class WanS2VSelfAttention(SelfAttention):
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x)
q = rope_apply(q, freqs)
k = rope_apply(k, freqs)
x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v)
return self.o(x)
class WanS2VDiTBlock(DiTBlock):
def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)
def forward(self, x, context, e, freqs):
seg_idx = e[1].item()
seg_idx = min(max(0, seg_idx), x.size(1))
seg_idx = [0, seg_idx, x.size(1)]
e = e[0]
modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device)
e = (modulation + e).chunk(6, dim=1)
e = [element.squeeze(1) for element in e]
norm_x = self.norm1(x)
parts = []
for i in range(2):
parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
norm_x = torch.cat(parts, dim=1)
# self-attention
y = self.self_attn(norm_x, freqs)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
x = x + self.cross_attn(self.norm3(x), context)
norm2_x = self.norm2(x)
parts = []
for i in range(2):
parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
norm2_x = torch.cat(parts, dim=1)
y = self.ffn(norm2_x)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
return x
x = cross_attn_ffn(x, context, e)
return x
class S2VHead(Head):
def forward(self, x, t_mod):
t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
return x
class WanS2VModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
cond_dim: int,
audio_dim: int,
num_audio_token: int,
enable_adain: bool = True,
audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
zero_timestep: bool = True,
add_last_motion: bool = True,
framepack_drop_mode: str = "padd",
fuse_vae_embedding_in_latents: bool = True,
require_vae_embedding: bool = False,
seperated_timestep: bool = False,
require_clip_embedding: bool = False,
):
super().__init__()
self.dim = dim
self.in_dim = in_dim
self.freq_dim = freq_dim
self.patch_size = patch_size
self.num_heads = num_heads
self.enbale_adain = enable_adain
self.add_last_motion = add_last_motion
self.zero_timestep = zero_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.require_vae_embedding = require_vae_embedding
self.seperated_timestep = seperated_timestep
self.require_clip_embedding = require_clip_embedding
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = S2VHead(dim, out_dim, patch_size, eps)
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
# TODO: refactor dfs
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
dim=dim,
num_heads=num_heads,
inject_layer=audio_inject_layers,
enable_adain=enable_adain,
adain_dim=dim,
)
self.trainable_cond_mask = nn.Embedding(3, dim)
self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)
def patchify(self, x: torch.Tensor):
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x,
'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0],
h=grid_size[1],
w=grid_size[2],
x=self.patch_size[0],
y=self.patch_size[1],
z=self.patch_size[2]
)
def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
if drop_motion_frames:
return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
else:
return flattern_mot, mot_remb
def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
mask_input = torch.cat(
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
)
return x, seq_lens, rope_embs, mask_input
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
if block_idx in self.audio_injector.injected_block_id.keys():
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
num_frames = audio_emb.shape[1]
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
attn_hidden_states = adain_hidden_states
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
return hidden_states
@staticmethod
def state_dict_converter():
return WanS2VModelStateDictConverter()
class WanS2VModelStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
config = {}
if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
config = {
"dim": 5120,
"in_dim": 16,
"ffn_dim": 13824,
"out_dim": 16,
"text_dim": 4096,
"freq_dim": 256,
"eps": 1e-06,
"patch_size": (1, 2, 2),
"num_heads": 40,
"num_layers": 40,
"cond_dim": 16,
"audio_dim": 1024,
"num_audio_token": 4,
}
return state_dict, config
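For reference, a minimal sketch of how the converter's config maps onto the constructor; the explicit dict below mirrors the one returned by from_civitai for the 966cffdc... checkpoint, while in normal use the loader builds the model automatically via the hash registered in model_loader_configs:

from diffsynth.models.wan_video_dit_s2v import WanS2VModel

config = dict(
    dim=5120, in_dim=16, ffn_dim=13824, out_dim=16, text_dim=4096,
    freq_dim=256, eps=1e-6, patch_size=(1, 2, 2), num_heads=40,
    num_layers=40, cond_dim=16, audio_dim=1024, num_audio_token=4,
)
model = WanS2VModel(**config)  # weights would then be loaded with model.load_state_dict(...)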

View File: diffsynth/models/wav2vec.py (new file, 199 lines)

@@ -0,0 +1,199 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2)
seq_len = features.shape[2] / float(input_fps)
if output_len is None:
output_len = int(seq_len * output_fps)
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2)
class WanS2VAudioEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
config = {
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
"activation_dropout": 0.05,
"apply_spec_augment": True,
"architectures": ["Wav2Vec2ForCTC"],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": True,
"conv_dim": [512, 512, 512, 512, 512, 512, 512],
"conv_kernel": [10, 3, 3, 3, 3, 2, 2],
"conv_stride": [5, 2, 2, 2, 2, 2, 2],
"ctc_loss_reduction": "mean",
"ctc_zero_infinity": True,
"do_stable_layer_norm": True,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.05,
"final_dropout": 0.0,
"hidden_act": "gelu",
"hidden_dropout": 0.05,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.05,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.05,
"mask_time_selection": "static",
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.7.0.dev0",
"vocab_size": 33
}
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
self.video_rate = 30
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32):
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype)
# retrieve logits & take argmax
res = self.model(input_values.to(self.model.device), output_hidden_states=True)
if return_all_layers:
feat = torch.cat(res.hidden_states)
else:
feat = res.hidden_states[-1]
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
z = feat.to(dtype)
return z
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
bucket_num = min_batch_num * batch_frames
batch_idx = [stride * i for i in range(bucket_num)]
batch_audio_eb = []
for bi in batch_idx:
if bi < audio_frame_num:
audio_sample_stride = 2
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
scale = self.video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
)
batch_audio_eb = []
audio_sample_stride = int(self.video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
@staticmethod
def state_dict_converter():
return WanS2VAudioEncoderStateDictConverter()
class WanS2VAudioEncoderStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict = {'model.' + k: v for k, v in state_dict.items()}
return state_dict
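As a quick illustration of the feature-rate conversion used above, here is a standalone call to linear_interpolation with dummy tensors (shapes are illustrative; the wav2vec hidden states arrive at roughly 50 Hz and are resampled to the 30 fps internal video rate):

import torch
from diffsynth.models.wav2vec import linear_interpolation

feat_50hz = torch.randn(1, 100, 1024)  # 2 seconds of hidden states at 50 Hz
feat_30fps = linear_interpolation(feat_50hz, input_fps=50, output_fps=30)
print(feat_30fps.shape)  # torch.Size([1, 60, 1024])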

View File

@@ -15,6 +15,7 @@ from typing_extensions import Literal
from ..utils import BasePipeline, ModelConfig, PipelineUnit, PipelineUnitRunner
from ..models import ModelManager, load_state_dict
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
from ..models.wan_video_dit_s2v import rope_precompute
from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from ..models.wan_video_image_encoder import WanImageEncoder
@@ -49,6 +50,7 @@ class WanVideoPipeline(BasePipeline):
self.units = [
WanVideoUnit_ShapeChecker(),
WanVideoUnit_NoiseInitializer(),
WanVideoUnit_S2V(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedderVAE(),
@@ -127,6 +129,8 @@ class WanVideoPipeline(BasePipeline):
torch.nn.LayerNorm: WanAutoCastLayerNorm,
RMSNorm: AutoWrappedModule,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.Conv1d: AutoWrappedModule,
torch.nn.Embedding: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
@@ -254,6 +258,24 @@ class WanVideoPipeline(BasePipeline):
),
vram_limit=vram_limit,
)
if self.audio_encoder is not None:
# TODO: needs verification
dtype = next(iter(self.audio_encoder.parameters())).dtype
enable_vram_management(
self.audio_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=dtype,
computation_device=self.device,
),
)
def initialize_usp(self):
@@ -290,6 +312,7 @@ class WanVideoPipeline(BasePipeline):
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*"),
audio_processor_config: ModelConfig = None,
redirect_common_files: bool = True,
use_usp=False,
):
@@ -332,7 +355,8 @@ class WanVideoPipeline(BasePipeline):
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
# Size division factor
if pipe.vae is not None:
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
@@ -342,7 +366,11 @@ class WanVideoPipeline(BasePipeline):
tokenizer_config.download_if_necessary(use_usp=use_usp)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
if audio_processor_config is not None:
audio_processor_config.download_if_necessary(use_usp=use_usp)
from transformers import Wav2Vec2Processor
pipe.audio_processor = Wav2Vec2Processor.from_pretrained(audio_processor_config.path)
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
return pipe
@@ -361,6 +389,10 @@ class WanVideoPipeline(BasePipeline):
# Video-to-video
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
# Speech-to-video
input_audio: Optional[str] = None,
audio_sample_rate: Optional[int] = 16000,
s2v_pose_video: Optional[list[Image.Image]] = None,
# ControlNet
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
@@ -429,6 +461,7 @@ class WanVideoPipeline(BasePipeline):
"motion_bucket_id": motion_bucket_id, "motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video,
} }
for unit in self.units: for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -868,6 +901,67 @@ class WanVideoUnit_CfgMerger(PipelineUnit):
return inputs_shared, inputs_posi, inputs_nega return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_S2V(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("audio_encoder", "vae", )
)
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
return {}
pipe.load_models_to_device(["audio_encoder"])
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True)
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
z, fps=16, batch_frames=num_frames - 1, m=0
)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).to(pipe.device, pipe.torch_dtype)
if len(audio_embed_bucket.shape) == 3:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
elif len(audio_embed_bucket.shape) == 4:
audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
audio_embed_bucket = audio_embed_bucket[..., 0:num_frames-1]
return {"audio_input": audio_embed_bucket}
def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
# TODO: may support input motion latents
motion_frames = 73
motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
lat_motion_frames = (motion_frames + 3) // 4
motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
pipe.load_models_to_device(["vae"])
if s2v_pose_video is None:
input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
else:
input_video = pipe.preprocess_video(s2v_pose_video)
# take the first num_frames frames
input_video = input_video[:, :, :num_frames]
# pad if not enough frames
padding_frames = num_frames - input_video.shape[2]
input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
# encode to latents
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
return {"pose_cond": input_latents[:,:,1:]}
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("input_audio") is None or pipe.audio_encoder is None or pipe.audio_processor is None:
return inputs_shared, inputs_posi, inputs_nega
input_audio, audio_sample_rate, s2v_pose_video, num_frames, height, width = inputs_shared.get("input_audio"), inputs_shared.get("audio_sample_rate"), inputs_shared.get("s2v_pose_video"), inputs_shared.get("num_frames"), inputs_shared.get("height"), inputs_shared.get("width")
tiled, tile_size, tile_stride = inputs_shared.get("tiled"), inputs_shared.get("tile_size"), inputs_shared.get("tile_stride")
audio_input_positive = self.process_audio(pipe, input_audio, audio_sample_rate, num_frames)
inputs_posi.update(audio_input_positive)
inputs_nega.update({"audio_input": 0.0 * audio_input_positive["audio_input"]})
inputs_shared.update(self.process_motion_latents(pipe, height, width, tiled, tile_size, tile_stride))
inputs_shared.update(self.process_pose_cond(pipe, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride))
return inputs_shared, inputs_posi, inputs_nega
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
@@ -987,6 +1081,10 @@ def model_fn_wan_video(
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
audio_input: Optional[torch.Tensor] = None,
motion_latents: Optional[torch.Tensor] = None,
motion_frames: Optional[list] = None,
pose_cond: Optional[torch.Tensor] = None,
tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None,
@@ -1024,7 +1122,21 @@ def model_fn_wan_video(
tensor_names=["latents", "y"],
batch_size=2 if cfg_merge else 1
)
# wan2.2 s2v
if audio_input is not None:
return model_fn_wans2v(
dit=dit,
latents=latents,
timestep=timestep,
context=context,
audio_input=audio_input,
motion_latents=motion_latents,
motion_frames=motion_frames,
pose_cond=pose_cond,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
use_gradient_checkpointing=use_gradient_checkpointing,
)
if use_unified_sequence_parallel:
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
@@ -1143,3 +1255,107 @@ def model_fn_wan_video(
f -= 1
x = dit.unpatchify(x, (f, h, w))
return x
def model_fn_wans2v(
dit,
latents,
timestep,
context,
audio_input,
motion_latents,
motion_frames,
pose_cond,
use_gradient_checkpointing_offload=False,
use_gradient_checkpointing=False
):
origin_ref_latents = latents[:, :, 0:1]
latents = latents[:, :, 1:]
audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
# reference image
x = latents
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
ref_grid_sizes = [[
torch.tensor([30, 0, 0]).unsqueeze(0),
torch.tensor([31, rh, rw]).unsqueeze(0),
torch.tensor([1, rh, rw]).unsqueeze(0),
]]
original_seq_len = seq_lens[0]
seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
grid_sizes = grid_sizes + ref_grid_sizes
x = torch.cat([x, ref_latents], dim=1)
mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
mask[:, -ref_latents.shape[1]:] = 1
b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
x = x + dit.trainable_cond_mask(mask).to(x.dtype)
# t_mod
if dit.zero_timestep:
timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
if dit.zero_timestep:
e = e[:-1]
zero_e0 = e0[-1:]
e0 = e0[:-1]
e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
e0 = [e0, original_seq_len]
# context
context = dit.text_embedding(context)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block_id, block in enumerate(dit.blocks):
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
x,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
x, context, e0, pre_compute_freqs,
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
x,
use_reentrant=False,
)
else:
x = block(x, context, e0, pre_compute_freqs)
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
x = x[:, :original_seq_len]
x = dit.head(x, e)
x = dit.unpatchify(x, (f, h, w))
x = torch.cat([origin_ref_latents, x], dim=2)
return x

View File

@@ -0,0 +1,63 @@
import torch
from PIL import Image
import librosa
from diffsynth import save_video, VideoData, save_video_with_audio
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
],
audio_processor_config=ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/"),
)
num_frames = 81 # 4n+1
height = 448
width = 832
prompt = "a person is singing"
input_image = Image.open("/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.png").convert("RGB").resize((width, height))
# s2v audio input; a 16 kHz sampling rate is recommended
audio_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/sing.MP3'
input_audio, sample_rate = librosa.load(audio_path, sr=16000)
# Speech-to-video
video = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt="",
seed=0,
num_frames=num_frames,
height=height,
width=width,
audio_sample_rate=sample_rate,
input_audio=input_audio,
num_inference_steps=40,
)
save_video_with_audio(video, "video_with_audio.mp4", audio_path, fps=16, quality=5)
# s2v uses the first num_frames frames of the pose video; its height and width must match input_image, and its fps should be 16, matching the output video fps.
pose_video_path = '/mnt/nas1/zhanghong/project/aigc/Wan2.2_s2v/examples/pose.mp4'
pose_video = VideoData(pose_video_path, height=height, width=width)
pose_video.set_length(num_frames)
# Speech-to-video with pose
video = pipe(
prompt=prompt,
input_image=input_image,
negative_prompt="",
seed=0,
num_frames=num_frames,
height=height,
width=width,
audio_sample_rate=sample_rate,
input_audio=input_audio,
s2v_pose_video=pose_video,
num_inference_steps=40,
)
save_video_with_audio(video, "video_pose_with_audio.mp4", audio_path, fps=16, quality=5)
save_video(pose_video, "video_pose_input.mp4", fps=16, quality=5)
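For reference, a small length check one could append to the script above (81 frames at 16 fps is about 5.1 seconds; shorter audio is zero-padded and longer audio is truncated to num_frames - 1 audio frames by the pipeline):

print(f"audio: {len(input_audio) / sample_rate:.1f}s, video: {num_frames / 16:.1f}s")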