support s2v framepack

This commit is contained in:
mi804
2025-09-01 16:48:46 +08:00
parent 1d240994e7
commit 5cee326f92
8 changed files with 220 additions and 56 deletions

View File

@@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module):
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
# TODO: refactor dfs
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
@@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module):
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
# TODO: check drop_motion_frames = False
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)

View File

@@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module):
return batch_audio_eb, min_batch_num
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
    """Extract audio features and split them into per-inference batches.

    Runs the audio encoder over the raw waveform, buckets the resulting
    embeddings by frame rate, and slices the bucket into chunks of
    ``batch_frames`` frames along the last axis — one chunk per inference.

    Returns a list of tensors, each covering ``batch_frames`` frames.
    """
    # Full per-layer feature stack for the whole clip.
    features = self.extract_audio_feat(
        input_audio, sample_rate, processor,
        return_all_layers=True, dtype=dtype, device=device,
    )
    bucket, num_batches = self.get_audio_embed_bucket_fps(
        features, fps=fps, batch_frames=batch_frames, m=m,
    )
    # Add a batch dim and move the frame axis last before slicing.
    bucket = bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
    chunks = []
    for batch_idx in range(num_batches):
        start = batch_idx * batch_frames
        chunks.append(bucket[..., start:start + batch_frames])
    return chunks
@staticmethod
def state_dict_converter():
    """Return the state-dict converter associated with this encoder."""
    converter = WanS2VAudioEncoderStateDictConverter()
    return converter