support s2v framepack

This commit is contained in:
mi804
2025-09-01 16:48:46 +08:00
parent 1d240994e7
commit 5cee326f92
8 changed files with 220 additions and 56 deletions

View File

@@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module):
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
# TODO: refactor dfs
self.audio_injector = AudioInjector_WAN(
all_modules,
all_modules_names,
@@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module):
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
# inject the motion frames token to the hidden states
# TODO: check drop_motion_frames = False
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
if len(mot) > 0:
x = torch.cat([x, mot[0]], dim=1)