diff --git a/diffsynth/models/wan_video_dit_s2v.py b/diffsynth/models/wan_video_dit_s2v.py
index 3121c98..b0016df 100644
--- a/diffsynth/models/wan_video_dit_s2v.py
+++ b/diffsynth/models/wan_video_dit_s2v.py
@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Tuple
 from .utils import hash_state_dict_keys
-from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, SelfAttention, Head, CrossAttention
+from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
 
 
 def torch_dfs(model: nn.Module, parent_name='root'):
@@ -24,22 +24,6 @@ def torch_dfs(model: nn.Module, parent_name='root'):
     return modules, module_names
 
 
-def rope_apply(x, freqs):
-    n, c = x.size(2), x.size(3) // 2
-    # loop over samples
-    output = []
-    for i, _ in enumerate(x):
-        s = x.size(1)
-        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
-        freqs_i = freqs[i, :s]
-        # apply rotary embedding
-        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-        x_i = torch.cat([x_i, x[i, s:]])
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output).to(x.dtype)
-
-
 def rope_precompute(x, grid_sizes, freqs, start=None):
     b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
 
@@ -135,11 +119,8 @@ class MotionEncoder_tc(nn.Module):
         self.final_linear = nn.Linear(hidden_dim, hidden_dim, **factory_kwargs)
 
         self.norm1 = nn.LayerNorm(hidden_dim // 4, elementwise_affine=False, eps=1e-6, **factory_kwargs)
-        self.norm2 = nn.LayerNorm(hidden_dim // 2, elementwise_affine=False, eps=1e-6, **factory_kwargs)
-        self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
-
         self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
 
     def forward(self, x):
@@ -358,71 +339,21 @@ class CausalAudioEncoder(nn.Module):
         return res  # b f n dim
 
 
-class WanS2VSelfAttention(SelfAttention):
-
-    def forward(self, x, freqs):
-        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
-        q = self.norm_q(self.q(x)).view(b, s, n, d)
-        k = self.norm_k(self.k(x)).view(b, s, n, d)
-        v = self.v(x)
-        q = rope_apply(q, freqs)
-        k = rope_apply(k, freqs)
-        x = self.attn(q.view(b, s, n * d), k.view(b, s, n * d), v)
-        return self.o(x)
-
-
 class WanS2VDiTBlock(DiTBlock):
 
-    def __init__(self, dim, num_heads, ffn_dim, eps=1e-6, has_image_input=False):
-        super().__init__(has_image_input=has_image_input, dim=dim, num_heads=num_heads, ffn_dim=ffn_dim, eps=eps)
-        self.self_attn = WanS2VSelfAttention(dim, num_heads, eps)
-
-    def forward(self, x, context, e, freqs):
-        seg_idx = e[1].item()
-        seg_idx = min(max(0, seg_idx), x.size(1))
-        seg_idx = [0, seg_idx, x.size(1)]
-        e = e[0]
-        modulation = self.modulation.unsqueeze(2).to(dtype=e.dtype, device=e.device)
-        e = (modulation + e).chunk(6, dim=1)
-        e = [element.squeeze(1) for element in e]
-        norm_x = self.norm1(x)
-        parts = []
-        for i in range(2):
-            parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
-        norm_x = torch.cat(parts, dim=1)
-        # self-attention
-        y = self.self_attn(norm_x, freqs)
-        z = []
-        for i in range(2):
-            z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
-        y = torch.cat(z, dim=1)
-        x = x + y
-
-        # cross-attention & ffn function
-        def cross_attn_ffn(x, context, e):
-            x = x + self.cross_attn(self.norm3(x), context)
-            norm2_x = self.norm2(x)
-            parts = []
-            for i in range(2):
-                parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] * (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
-            norm2_x = torch.cat(parts, dim=1)
-            y = self.ffn(norm2_x)
-            z = []
-            for i in range(2):
-                z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
-            y = torch.cat(z, dim=1)
-            x = x + y
-            return x
-
-        x = cross_attn_ffn(x, context, e)
-        return x
-
-
-class S2VHead(Head):
-
-    def forward(self, x, t_mod):
-        t_mod = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod.unsqueeze(1)).chunk(2, dim=1)
-        x = (self.head(self.norm(x) * (1 + t_mod[1]) + t_mod[0]))
+    def forward(self, x, context, t_mod, seq_len_x, freqs):
+        t_mod = (self.modulation.unsqueeze(2).to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
+        # t_mod[:, :, 0] for x, t_mod[:, :, 1] for others like ref, motion, etc.
+        t_mod = [
+            torch.cat([element[:, :, 0].expand(1, seq_len_x, x.shape[-1]), element[:, :, 1].expand(1, x.shape[1] - seq_len_x, x.shape[-1])], dim=1)
+            for element in t_mod
+        ]
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
+        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
+        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
+        x = x + self.cross_attn(self.norm3(x), context)
+        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        x = self.gate(x, gate_mlp, self.ffn(input_x))
         return x
 
 
@@ -472,9 +403,9 @@ class WanS2VModel(torch.nn.Module):
 
         self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
-        self.blocks = nn.ModuleList([WanS2VDiTBlock(dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
-        self.head = S2VHead(dim, out_dim, patch_size, eps)
-        self.freqs = precompute_freqs_cis_3d(dim // num_heads)
+        self.blocks = nn.ModuleList([WanS2VDiTBlock(False, dim, num_heads, ffn_dim, eps) for _ in range(num_layers)])
+        self.head = Head(dim, out_dim, patch_size, eps)
+        self.freqs = torch.cat(precompute_freqs_cis_3d(dim // num_heads), dim=1)
 
         self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
         self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
@@ -516,17 +447,17 @@ class WanS2VModel(torch.nn.Module):
         else:
             return flattern_mot, mot_remb
 
-    def inject_motion(self, x, seq_lens, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
+    def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
         # inject the motion frames token to the hidden states
+        # TODO: check drop_motion_frames = False
         mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
         if len(mot) > 0:
             x = torch.cat([x, mot[0]], dim=1)
-            seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot], dtype=torch.long)
             rope_embs = torch.cat([rope_embs, mot_remb[0]], dim=1)
             mask_input = torch.cat(
                 [mask_input, 2 * torch.ones([1, x.shape[1] - mask_input.shape[1]], device=mask_input.device, dtype=mask_input.dtype)], dim=1
             )
-        return x, seq_lens, rope_embs, mask_input
+        return x, rope_embs, mask_input
 
     def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
         if block_idx in self.audio_injector.injected_block_id.keys():
@@ -548,6 +479,118 @@ class WanS2VModel(torch.nn.Module):
 
         return hidden_states
 
+    def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
+        audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
+        audio_emb_global, audio_emb = self.casual_audio_encoder(audio_input)
+        audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
+        merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
+        return audio_emb_global, merged_audio_emb
+
+    def get_grid_sizes(self, grid_size_x, grid_size_ref):
+        f, h, w = grid_size_x
+        rf, rh, rw = grid_size_ref
+        grid_sizes_x = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
+        grid_sizes_x = [[torch.zeros_like(grid_sizes_x), grid_sizes_x, grid_sizes_x]]
+        grid_sizes_ref = [[
+            torch.tensor([30, 0, 0]).unsqueeze(0),
+            torch.tensor([31, rh, rw]).unsqueeze(0),
+            torch.tensor([1, rh, rw]).unsqueeze(0),
+        ]]
+        return grid_sizes_x + grid_sizes_ref
+
+    def forward(
+        self,
+        latents,
+        timestep,
+        context,
+        audio_input,
+        motion_latents,
+        pose_cond,
+        use_gradient_checkpointing_offload=False,
+        use_gradient_checkpointing=False
+    ):
+        origin_ref_latents = latents[:, :, 0:1]
+        x = latents[:, :, 1:]
+
+        # context embedding
+        context = self.text_embedding(context)
+
+        # audio encode
+        audio_emb_global, merged_audio_emb = self.cal_audio_emb(audio_input)
+
+        # x and pose_cond
+        pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
+        x, (f, h, w) = self.patchify(self.patch_embedding(x) + self.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
+        seq_len_x = x.shape[1]
+
+        # reference image
+        ref_latents, (rf, rh, rw) = self.patchify(self.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
+        grid_sizes = self.get_grid_sizes((f, h, w), (rf, rh, rw))
+        x = torch.cat([x, ref_latents], dim=1)
+        # mask
+        mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
+        # freqs
+        pre_compute_freqs = rope_precompute(
+            x.detach().view(1, x.size(1), self.num_heads, self.dim // self.num_heads), grid_sizes, self.freqs, start=None
+        )
+        # motion
+        x, pre_compute_freqs, mask = self.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
+
+        x = x + self.trainable_cond_mask(mask).to(x.dtype)
+
+        # t_mod
+        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
+        t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
+        t_mod = self.time_projection(t).unflatten(1, (6, self.dim)).unsqueeze(2).transpose(0, 2)
+
+        def create_custom_forward(module):
+            def custom_forward(*inputs):
+                return module(*inputs)
+            return custom_forward
+
+        for block_id, block in enumerate(self.blocks):
+            if use_gradient_checkpointing_offload:
+                with torch.autograd.graph.save_on_cpu():
+                    x = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        x,
+                        context,
+                        t_mod,
+                        seq_len_x,
+                        pre_compute_freqs,
+                        use_reentrant=False,
+                    )
+                    x = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
+                        x,
+                        use_reentrant=False,
+                    )
+            elif use_gradient_checkpointing:
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    x,
+                    context,
+                    t_mod,
+                    seq_len_x,
+                    pre_compute_freqs,
+                    use_reentrant=False,
+                )
+                x = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(lambda x: self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
+                    x,
+                    use_reentrant=False,
+                )
+            else:
+                x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
+                x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
+
+        x = x[:, :seq_len_x]
+        x = self.head(x, t[:-1])
+        x = self.unpatchify(x, (f, h, w))
+        # make compatible with wan video
+        x = torch.cat([origin_ref_latents, x], dim=2)
+        return x
+
     @staticmethod
     def state_dict_converter():
         return WanS2VModelStateDictConverter()
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 16df7f4..1362d09 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -927,24 +927,23 @@ class WanVideoUnit_S2V(PipelineUnit):
 
     def process_motion_latents(self, pipe: WanVideoPipeline, height, width, tiled, tile_size, tile_stride):
         pipe.load_models_to_device(["vae"])
-        # TODO: may support input motion latents
+        # TODO: may support input motion latents, which relates to `drop_motion_frames = False`
         motion_frames = 73
+        lat_motion_frames = (motion_frames + 3) // 4 # 19
        motion_latents = torch.zeros([1, 3, motion_frames, height, width], dtype=pipe.torch_dtype, device=pipe.device)
-        lat_motion_frames = (motion_frames + 3) // 4
         motion_latents = pipe.vae.encode(motion_latents, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
-        return {"motion_latents": motion_latents, "motion_frames": [motion_frames, lat_motion_frames]}
+        return {"motion_latents": motion_latents}
 
     def process_pose_cond(self, pipe: WanVideoPipeline, s2v_pose_video, num_frames, height, width, tiled, tile_size, tile_stride):
-        pipe.load_models_to_device(["vae"])
         if s2v_pose_video is None:
-            input_video = -torch.ones(1, 3, num_frames, height, width, device=pipe.device, dtype=pipe.torch_dtype)
-        else:
-            input_video = pipe.preprocess_video(s2v_pose_video)
-            # get num_frames-1 frames
-            input_video = input_video[:, :, :num_frames]
-            # pad if not enough frames
-            padding_frames = num_frames - input_video.shape[2]
-            input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
+            return {"pose_cond": None}
+        pipe.load_models_to_device(["vae"])
+        input_video = pipe.preprocess_video(s2v_pose_video)
+        # get num_frames-1 frames
+        input_video = input_video[:, :, :num_frames]
+        # pad if not enough frames
+        padding_frames = num_frames - input_video.shape[2]
+        input_video = torch.cat([input_video, -torch.ones(1, 3, padding_frames, height, width, device=input_video.device, dtype=input_video.dtype)], dim=2)
         # encode to latents
         input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
         return {"pose_cond": input_latents[:,:,1:]}
@@ -1084,7 +1083,6 @@ def model_fn_wan_video(
     vace_scale = 1.0,
     audio_input: Optional[torch.Tensor] = None,
     motion_latents: Optional[torch.Tensor] = None,
-    motion_frames: Optional[list] = None,
     pose_cond: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
@@ -1132,10 +1130,10 @@
             context=context,
             audio_input=audio_input,
             motion_latents=motion_latents,
-            motion_frames=motion_frames,
             pose_cond=pose_cond,
             use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
             use_gradient_checkpointing=use_gradient_checkpointing,
+            use_unified_sequence_parallel=use_unified_sequence_parallel,
         )
 
     if use_unified_sequence_parallel:
@@ -1265,62 +1263,47 @@ def model_fn_wans2v(
     context,
     audio_input,
     motion_latents,
-    motion_frames,
     pose_cond,
     use_gradient_checkpointing_offload=False,
-    use_gradient_checkpointing=False
+    use_gradient_checkpointing=False,
+    use_unified_sequence_parallel=False,
 ):
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
     origin_ref_latents = latents[:, :, 0:1]
-    latents = latents[:, :, 1:]
+    x = latents[:, :, 1:]
 
-    audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames[0]), audio_input], dim=-1)
-    audio_emb_global, audio_emb = dit.casual_audio_encoder(audio_input)
-    audio_emb_global = audio_emb_global[:, motion_frames[1]:].clone()
-    merged_audio_emb = audio_emb[:, motion_frames[1]:, :]
+    # context embedding
+    context = dit.text_embedding(context)
+
+    # audio encode
+    audio_emb_global, merged_audio_emb = dit.cal_audio_emb(audio_input)
+
+    # x and pose_cond
+    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
+    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
+    seq_len_x = x.shape[1]
 
     # reference image
-    x = latents
-    pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
-    x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
-
-    grid_sizes = torch.tensor([f, h, w], dtype=torch.long).unsqueeze(0)
-    seq_lens = torch.tensor([x.size(1)], dtype=torch.long)
-    grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
-
-    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
-
-    ref_grid_sizes = [[
-        torch.tensor([30, 0, 0]).unsqueeze(0),
-        torch.tensor([31, rh, rw]).unsqueeze(0),
-        torch.tensor([1, rh, rw]).unsqueeze(0),
-    ]]
-    original_seq_len = seq_lens[0]
-    seq_lens = seq_lens + torch.tensor([ref_latents.shape[1]], dtype=torch.long)
-    grid_sizes = grid_sizes + ref_grid_sizes
-
+    ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
+    grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
     x = torch.cat([x, ref_latents], dim=1)
-    mask = torch.zeros([1, x.shape[1]], dtype=torch.long, device=x.device)
-    mask[:, -ref_latents.shape[1]:] = 1
+    # mask
+    mask = torch.cat([torch.zeros([1, seq_len_x]), torch.ones([1, ref_latents.shape[1]])], dim=1).to(torch.long).to(x.device)
+    # freqs
+    pre_compute_freqs = rope_precompute(x.detach().view(1, x.size(1), dit.num_heads, dit.dim // dit.num_heads), grid_sizes, dit.freqs, start=None)
+    # motion
+    x, pre_compute_freqs, mask = dit.inject_motion(x, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
 
-    b, s, n, d = x.size(0), x.size(1), dit.num_heads, dit.dim // dit.num_heads
-    pre_compute_freqs = rope_precompute(x.detach().view(b, s, n, d), grid_sizes, torch.cat(dit.freqs, dim=1), start=None)
-
-    x, seq_lens, pre_compute_freqs, mask = dit.inject_motion(x, seq_lens, pre_compute_freqs, mask, motion_latents, add_last_motion=2)
     x = x + dit.trainable_cond_mask(mask).to(x.dtype)
 
-    # t_mod
-    if dit.zero_timestep:
-        timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
-    e = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
-    e0 = dit.time_projection(e).unflatten(1, (6, dit.dim))
-    if dit.zero_timestep:
-        e = e[:-1]
-        zero_e0 = e0[-1:]
-        e0 = e0[:-1]
-        e0 = torch.cat([e0.unsqueeze(2), zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)], dim=2)
-    e0 = [e0, original_seq_len]
-    # context
-    context = dit.text_embedding(context)
+    # t_mod
+    timestep = torch.cat([timestep, torch.zeros([1], dtype=timestep.dtype, device=timestep.device)])
+    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
 
     def create_custom_forward(module):
         def custom_forward(*inputs):
@@ -1332,31 +1315,32 @@ def model_fn_wans2v(
             with torch.autograd.graph.save_on_cpu():
                 x = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
-                    x, context, e0, pre_compute_freqs,
+                    x, context, t_mod, seq_len_x, pre_compute_freqs,
                     use_reentrant=False,
                 )
                 x = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
+                    create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                     x,
                     use_reentrant=False,
                 )
        elif use_gradient_checkpointing:
             x = torch.utils.checkpoint.checkpoint(
                 create_custom_forward(block),
-                x, context, e0, pre_compute_freqs,
+                x, context, t_mod, seq_len_x, pre_compute_freqs,
                 use_reentrant=False,
             )
             x = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)),
+                create_custom_forward(lambda x: dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)),
                 x,
                 use_reentrant=False,
             )
         else:
-            x = block(x, context, e0, pre_compute_freqs)
-            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, original_seq_len)
+            x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
+            x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
 
-    x = x[:, :original_seq_len]
-    x = dit.head(x, e)
+    x = x[:, :seq_len_x]
+    x = dit.head(x, t[:-1])
     x = dit.unpatchify(x, (f, h, w))
+    # make compatible with wan video
    x = torch.cat([origin_ref_latents, x], dim=2)
     return x
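
A minimal standalone sketch (not part of the patch, and with illustrative sizes rather than the real model dimensions) of the per-segment timestep modulation that the reworked WanS2VDiTBlock.forward and model_fn_wans2v now share: after a zero timestep is appended, column 0 of t_mod modulates the denoised video tokens and column 1 modulates the appended reference/motion tokens.

import torch

# Toy sizes, for illustration only (the real model uses dim=5120 and ~30k tokens).
dim, seq_len_x, seq_len_extra = 8, 12, 4
x = torch.randn(1, seq_len_x + seq_len_extra, dim)            # video tokens followed by ref/motion tokens

# time_projection output for [timestep, 0]: (2, 6, dim) -> unsqueeze/transpose -> (1, 6, 2, dim)
t_mod = torch.randn(2, 6, dim).unsqueeze(2).transpose(0, 2)
modulation = torch.zeros(1, 6, dim)                           # stands in for the block's learned modulation table

t_mod = (modulation.unsqueeze(2) + t_mod).chunk(6, dim=1)     # six tensors of shape (1, 1, 2, dim)
t_mod = [
    torch.cat([
        m[:, :, 0].expand(1, seq_len_x, dim),                 # column 0 -> denoised video tokens (real timestep)
        m[:, :, 1].expand(1, x.shape[1] - seq_len_x, dim),    # column 1 -> ref/motion tokens (zero timestep)
    ], dim=1)
    for m in t_mod
]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t_mod
print(shift_msa.shape)  # torch.Size([1, 16, 8]): one modulation vector per token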