wans2v inference

This commit is contained in:
mi804
2025-08-27 11:51:56 +08:00
parent 04e39f7de5
commit b541b9bed2
7 changed files with 1134 additions and 4 deletions

View File

@@ -0,0 +1,579 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from .utils import hash_state_dict_keys
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
def torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
def rope_apply(x, freqs):
n, c = x.size(2), x.size(3) // 2
# loop over samples
output = []
for i, _ in enumerate(x):
s = x.size(1)
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = freqs[i, :s]
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).to(x.dtype)
def rope_precompute(x, grid_sizes, freqs, start=None):
b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
# split freqs
if type(freqs) is list:
trainable_freqs = freqs[1]
freqs = freqs[0]
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = torch.view_as_complex(x.detach().reshape(b, s, n, -1, 2).to(torch.float64))
seq_bucket = [0]
if not type(grid_sizes) is list:
grid_sizes = [grid_sizes]
for g in grid_sizes:
if not type(g) is list:
g = [torch.zeros_like(g), g]
batch_size = g[0].shape[0]
for i in range(batch_size):
if start is None:
f_o, h_o, w_o = g[0][i]
else:
f_o, h_o, w_o = start[i]
f, h, w = g[1][i]
t_f, t_h, t_w = g[2][i]
seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
seq_len = int(seq_f * seq_h * seq_w)
if seq_len > 0:
if t_f > 0:
factor_f, factor_h, factor_w = (t_f / seq_f).item(), (t_h / seq_h).item(), (t_w / seq_w).item()
# Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
if f_o >= 0:
f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1, seq_f).astype(int).tolist()
else:
f_sam = np.linspace(-f_o.item(), (-t_f - f_o).item() + 1, seq_f).astype(int).tolist()
h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1, seq_h).astype(int).tolist()
w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1, seq_w).astype(int).tolist()
assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][f_sam].conj()
freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
freqs_i = torch.cat(
[
freqs_0.expand(seq_f, seq_h, seq_w, -1),
freqs[1][h_sam].view(1, seq_h, 1, -1).expand(seq_f, seq_h, seq_w, -1),
freqs[2][w_sam].view(1, 1, seq_w, -1).expand(seq_f, seq_h, seq_w, -1),
],
dim=-1
).reshape(seq_len, 1, -1)
elif t_f < 0:
freqs_i = trainable_freqs.unsqueeze(1)
# apply rotary embedding
output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
seq_bucket.append(seq_bucket[-1] + seq_len)
return output
class CausalConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode='replicate', **kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class MotionEncoder_tc(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, need_global=True, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
if need_global:
self.conv1_global = CausalConv1d(in_dim, hidden_dim // 4, 3, stride=1)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
if need_global:
self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1).to(device=x.device, dtype=x.dtype)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local
class FramePackMotioner(nn.Module):
def __init__(self, inner_dim=1024, num_heads=16, zip_frame_buckets=[1, 2, 16], drop_mode="drop", *args, **kwargs):
super().__init__(*args, **kwargs)
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
self.zip_frame_buckets = torch.tensor(zip_frame_buckets, dtype=torch.long)
self.inner_dim = inner_dim
self.num_heads = num_heads
self.freqs = torch.cat(precompute_freqs_cis_3d(inner_dim // num_heads), dim=1)
self.drop_mode = drop_mode
def forward(self, motion_latents, add_last_motion=2):
motion_frames = motion_latents[0].shape[1]
mot = []
mot_remb = []
for m in motion_latents:
lat_height, lat_width = m.shape[2], m.shape[3]
padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height, lat_width).to(device=m.device, dtype=m.dtype)
overlap_frame = min(padd_lat.shape[1], m.shape[1])
if overlap_frame > 0:
padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
if add_last_motion < 2 and self.drop_mode != "drop":
zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.__len__() - add_last_motion - 1].sum()
padd_lat[:, -zero_end_frame:] = 0
padd_lat = padd_lat.unsqueeze(0)
clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum():, :, :].split(
list(self.zip_frame_buckets)[::-1], dim=2
) # 16, 2 ,1
# patchfy
clean_latents_post = self.proj(clean_latents_post).flatten(2).transpose(1, 2)
clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(2).transpose(1, 2)
clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(2).transpose(1, 2)
if add_last_motion < 2 and self.drop_mode == "drop":
clean_latents_post = clean_latents_post[:, :0] if add_last_motion < 2 else clean_latents_post
clean_latents_2x = clean_latents_2x[:, :0] if add_last_motion < 1 else clean_latents_2x
motion_lat = torch.cat([clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
# rope
start_time_id = -(self.zip_frame_buckets[:1].sum())
end_time_id = start_time_id + self.zip_frame_buckets[0]
grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:2].sum())
end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
[
[torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
]
start_time_id = -(self.zip_frame_buckets[:3].sum())
end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
grid_sizes_4x = [
[
torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
torch.tensor([end_time_id, lat_height // 8, lat_width // 8]).unsqueeze(0).repeat(1, 1),
torch.tensor([self.zip_frame_buckets[2], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
]
]
grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
motion_rope_emb = rope_precompute(
motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads, self.inner_dim // self.num_heads),
grid_sizes,
self.freqs,
start=None
)
mot.append(motion_lat)
mot_remb.append(motion_rope_emb)
return mot, mot_remb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
output_dim: int,
norm_eps: float = 1e-5,
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, elementwise_affine=False)
def forward(self, x, temb):
temb = self.linear(F.silu(temb))
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
x = self.norm(x) * (1 + scale) + shift
return x
class AudioInjector_WAN(nn.Module):
def __init__(
self,
all_modules,
all_modules_names,
dim=2048,
num_heads=32,
inject_layer=[0, 27],
enable_adain=False,
adain_dim=2048,
):
super().__init__()
self.injected_block_id = {}
audio_injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, DiTBlock):
for inject_id in inject_layer:
if f'transformer_blocks.{inject_id}' in mod_name:
self.injected_block_id[inject_id] = audio_injector_id
audio_injector_id += 1
self.injector = nn.ModuleList([CrossAttention(
dim=dim,
num_heads=num_heads,
) for _ in range(audio_injector_id)])
self.injector_pre_norm_feat = nn.ModuleList([nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)])
self.injector_pre_norm_vec = nn.ModuleList([nn.LayerNorm(
dim,
elementwise_affine=False,
eps=1e-6,
) for _ in range(audio_injector_id)])
if enable_adain:
self.injector_adain_layers = nn.ModuleList([AdaLayerNorm(output_dim=dim * 2, embedding_dim=adain_dim) for _ in range(audio_injector_id)])
class CausalAudioEncoder(nn.Module):
def __init__(self, dim=5120, num_layers=25, out_dim=2048, num_token=4, need_global=False):
super().__init__()
self.encoder = MotionEncoder_tc(in_dim=dim, hidden_dim=out_dim, num_heads=num_token, need_global=need_global)
weight = torch.ones((1, num_layers, 1, 1)) * 0.01
self.weights = torch.nn.Parameter(weight)
self.act = torch.nn.SiLU()
def forward(self, features):
# features B * num_layers * dim * video_length
weights = self.act(self.weights.to(device=features.device, dtype=features.dtype))
weights_sum = weights.sum(dim=1, keepdims=True)
weighted_feat = ((features * weights) / weights_sum).sum(dim=1) # b dim f
weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
res = self.encoder(weighted_feat) # b f n dim
return res # b f n dim
class WanS2VSelfAttention(SelfAttention):
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x)
q = rope_apply(q, freqs)
k = rope_apply(k, freqs)
x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v)
return self.o(x)
class WanS2VDiTBlock(DiTBlock):
def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)
def forward(self, x, context, e, freqs):
seg_idx = e[1].item()
seg_idx = min(max(0, seg_idx), x.size(1))
seg_idx = [0, seg_idx, x.size(1)]
e = e[0]
modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device)
e = (modulation + e).chunk(6, dim=1)
e = [element.squeeze(1) for element in e]
norm_x = self.norm1(x)
parts = []
for i in range(2):
parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
norm_x = torch.cat(parts, dim=1)
# self-attention
y = self.self_attn(norm_x, freqs)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
# cross-attention & ffn function
def cross_attn_ffn(x, context, e):
x = x + self.cross_attn(self.norm3(x), context)
norm2_x = self.norm2(x)
parts = []
for i in range(2):
parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
norm2_x = torch.cat(parts, dim=1)
y = self.ffn(norm2_x)
z = []
for i in range(2):
z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
y = torch.cat(z, dim=1)
x = x + y
return x
x = cross_attn_ffn(x, context, e)
return x
class S2VHead(Head):
def forward(self, x, t_mod):
t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
return x
class WanS2VModel(torch.nn.Module):
def __init__(
self,
dim: int,
in_dim: int,
ffn_dim: int,
out_dim: int,
text_dim: int,
freq_dim: int,
eps: float,
patch_size: Tuple[int, int, int],
num_heads: int,
num_layers: int,
cond_dim: int,
audio_dim: int,
num_audio_token: int,
enable_adain: bool = True,
audio_inject_layers: list = [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39],
zero_timestep: bool = True,
add_last_motion: bool = True,
framepack_drop_mode: str = "padd",
fuse_vae_embedding_in_latents: bool = True,
require_vae_embedding: bool = False,
seperated_timestep: bool = False,
require_clip_embedding: bool = False,
):
super().__init__()
self.dim = dim
self.in_dim = in_dim
self.freq_dim = freq_dim
self.patch_size = patch_size
self.num_heads = num_heads
self.enbale_adain = enable_adain
self.add_last_motion = add_last_motion
self.zero_timestep = zero_timestep
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.require_vae_embedding = require_vae_embedding
self.seperated_timestep = seperated_timestep
self.require_clip_embedding = require_clip_embedding
self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'), nn.Linear(dim, dim))
self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
self.head = S2VHead(dim, out_dim, patch_size, eps)
self.freqs = precompute_freqs_cis_3d(dim // num_heads)
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
# TODO: refactor dfs
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
dim=dim,
num_heads=num_heads,
inject_layer=audio_inject_layers,
enable_adain=enable_adain,
adain_dim=dim,
)
self.trainable_cond_mask = nn.Embedding(3, dim)
self.frame_packer = FramePackMotioner(inner_dim=dim, num_heads=num_heads, zip_frame_buckets=[1, 2, 16], drop_mode=framepack_drop_mode)
def patchify(self, x: torch.Tensor):
grid_size = x.shape[2:]
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
return x, grid_size # x, grid_size: (f, h, w)
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
return rearrange(
x,
'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
f=grid_size[0],
h=grid_size[1],
w=grid_size[2],
x=self.patch_size[0],
y=self.patch_size[1],
z=self.patch_size[2]
)
def process_motion_frame_pack(self, motion_latents, drop_motion_frames=False, add_last_motion=2):
flattern_mot, mot_remb = self.frame_packer(motion_latents, add_last_motion)
if drop_motion_frames:
return [m[:, :0] for m in flattern_mot], [m[:, :0] for m in mot_remb]
else:
return flattern_mot, mot_remb
def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)
seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
mask_input = torch.cat(
[mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
)
return x, seq_lens, rope_embs, mask_input
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
if block_idx in self.audio_injector.injected_block_id.keys():
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
num_frames = audio_emb.shape[1]
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
audio_emb_global = rearrange(audio_emb_global, "b t n c -> (b t) n c")
adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](input_hidden_states, temb=audio_emb_global[:, 0])
attn_hidden_states = adain_hidden_states
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
return hidden_states
@staticmethod
def state_dict_converter():
return WanS2VModelStateDictConverter()
class WanS2VModelStateDictConverter:
def __init__(self):
pass
def from_civitai(self, state_dict):
config = {}
if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
config = {
"dim": 5120,
"in_dim": 16,
"ffn_dim": 13824,
"out_dim": 16,
"text_dim": 4096,
"freq_dim": 256,
"eps": 1e-06,
"patch_size": (1, 2, 2),
"num_heads": 40,
"num_layers": 40,
"cond_dim": 16,
"audio_dim": 1024,
"num_audio_token": 4,
}
return state_dict, config

199
diffsynth/models/wav2vec.py Normal file
View File

@@ -0,0 +1,199 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
required_duration = num_sample / target_fps
required_origin_frames = int(np.ceil(required_duration * original_fps))
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames
if max_start < 0:
raise ValueError("video length is too short")
start_frame = np.random.randint(0, max_start + 1)
start_time = start_frame / original_fps
end_time = start_time + required_duration
time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
frame_indices = np.clip(frame_indices, 0, total_frames - 1)
return frame_indices
def linear_interpolation(features, input_fps, output_fps, output_len=None):
"""
features: shape=[1, T, 512]
input_fps: fps for audio, f_a
output_fps: fps for video, f_m
output_len: video length
"""
features = features.transpose(1, 2)
seq_len = features.shape[2] / float(input_fps)
if output_len is None:
output_len = int(seq_len * output_fps)
output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear') # [1, 512, output_len]
return output_features.transpose(1, 2)
class WanS2VAudioEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
config = {
"_name_or_path": "facebook/wav2vec2-large-xlsr-53",
"activation_dropout": 0.05,
"apply_spec_augment": True,
"architectures": ["Wav2Vec2ForCTC"],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": True,
"conv_dim": [512, 512, 512, 512, 512, 512, 512],
"conv_kernel": [10, 3, 3, 3, 3, 2, 2],
"conv_stride": [5, 2, 2, 2, 2, 2, 2],
"ctc_loss_reduction": "mean",
"ctc_zero_infinity": True,
"do_stable_layer_norm": True,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.05,
"final_dropout": 0.0,
"hidden_act": "gelu",
"hidden_dropout": 0.05,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.05,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.05,
"mask_time_selection": "static",
"model_type": "wav2vec2",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.7.0.dev0",
"vocab_size": 33
}
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
self.video_rate = 30
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32):
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype)
# retrieve logits & take argmax
res = self.model(input_values.to(self.model.device), output_hidden_states=True)
if return_all_layers:
feat = torch.cat(res.hidden_states)
else:
feat = res.hidden_states[-1]
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
z = feat.to(dtype)
return z
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
bucket_num = min_batch_num * batch_frames
batch_idx = [stride * i for i in range(bucket_num)]
batch_audio_eb = []
for bi in batch_idx:
if bi < audio_frame_num:
audio_sample_stride = 2
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
num_layers, audio_frame_num, audio_dim = audio_embed.shape
if num_layers > 1:
return_all_layers = True
else:
return_all_layers = False
scale = self.video_rate / fps
min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
bucket_num = min_batch_num * batch_frames
padd_audio_num = math.ceil(min_batch_num * batch_frames / fps * self.video_rate) - audio_frame_num
batch_idx = get_sample_indices(
original_fps=self.video_rate, total_frames=audio_frame_num + padd_audio_num, target_fps=fps, num_sample=bucket_num, fixed_start=0
)
batch_audio_eb = []
audio_sample_stride = int(self.video_rate / fps)
for bi in batch_idx:
if bi < audio_frame_num:
chosen_idx = list(range(bi - m * audio_sample_stride, bi + (m + 1) * audio_sample_stride, audio_sample_stride))
chosen_idx = [0 if c < 0 else c for c in chosen_idx]
chosen_idx = [audio_frame_num - 1 if c >= audio_frame_num else c for c in chosen_idx]
if return_all_layers:
frame_audio_embed = audio_embed[:, chosen_idx].flatten(start_dim=-2, end_dim=-1)
else:
frame_audio_embed = audio_embed[0][chosen_idx].flatten()
else:
frame_audio_embed = \
torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
batch_audio_eb.append(frame_audio_embed)
batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0)
return batch_audio_eb, min_batch_num
@staticmethod
def state_dict_converter():
return WanS2VAudioEncoderStateDictConverter()
class WanS2VAudioEncoderStateDictConverter():
def __init__(self):
pass
def from_civitai(self, state_dict):
state_dict = {'model.' + k: v for k, v in state_dict.items()}
return state_dict