mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support s2v framepack
This commit is contained in:
@@ -410,7 +410,6 @@ class WanS2VModel(torch.nn.Module):
|
||||
self.cond_encoder = nn.Conv3d(cond_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.casual_audio_encoder = CausalAudioEncoder(dim=audio_dim, out_dim=dim, num_token=num_audio_token, need_global=enable_adain)
|
||||
all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
|
||||
# TODO: refactor dfs
|
||||
self.audio_injector = AudioInjector_WAN(
|
||||
all_modules,
|
||||
all_modules_names,
|
||||
@@ -449,7 +448,6 @@ class WanS2VModel(torch.nn.Module):
|
||||
|
||||
def inject_motion(self, x, rope_embs, mask_input, motion_latents, drop_motion_frames=True, add_last_motion=2):
|
||||
# inject the motion frames token to the hidden states
|
||||
# TODO: check drop_motion_frames = False
|
||||
mot, mot_remb = self.process_motion_frame_pack(motion_latents, drop_motion_frames=drop_motion_frames, add_last_motion=add_last_motion)
|
||||
if len(mot) > 0:
|
||||
x = torch.cat([x, mot[0]], dim=1)
|
||||
|
||||
@@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module):
|
||||
|
||||
return batch_audio_eb, min_batch_num
|
||||
|
||||
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
    """Extract audio features and split them into one chunk per inference pass.

    The audio is run through ``self.extract_audio_feat`` (all layers
    requested), bucketed with ``self.get_audio_embed_bucket_fps``, and the
    resulting bucket is cut into consecutive windows of ``batch_frames``
    along its last axis.

    Args:
        input_audio: Audio input, forwarded unchanged to
            ``self.extract_audio_feat``.
        sample_rate: Sample rate of ``input_audio``, forwarded unchanged.
        processor: Audio processor, forwarded unchanged.
        fps: Forwarded to ``self.get_audio_embed_bucket_fps``.
        batch_frames: Width of each window taken along the last axis; also
            forwarded to ``self.get_audio_embed_bucket_fps``.
        m: Forwarded to ``self.get_audio_embed_bucket_fps``.
        dtype: Dtype used for feature extraction and for the returned chunks.
        device: Device used for feature extraction and for the returned chunks.

    Returns:
        A list with ``min_batch_num`` tensors (the count reported by
        ``self.get_audio_embed_bucket_fps``), each a slice of the rearranged
        bucket covering ``batch_frames`` positions of the last axis.
    """
    features = self.extract_audio_feat(
        input_audio,
        sample_rate,
        processor,
        return_all_layers=True,
        dtype=dtype,
        device=device,
    )
    bucket, num_chunks = self.get_audio_embed_bucket_fps(
        features,
        fps=fps,
        batch_frames=batch_frames,
        m=m,
    )

    # Prepend a singleton axis, then rotate the bucket's original first axis
    # to the end so it can be windowed below.
    # NOTE(review): that axis is presumably the frame/time axis -- confirm
    # against get_audio_embed_bucket_fps.
    bucket = bucket.unsqueeze(0)
    bucket = bucket.permute(0, 2, 3, 1)
    bucket = bucket.to(device, dtype)

    chunks = []
    for chunk_index in range(num_chunks):
        start = chunk_index * batch_frames
        chunks.append(bucket[..., start:start + batch_frames])
    return chunks
|
||||
|
||||
@staticmethod
def state_dict_converter():
    """Return a new ``WanS2VAudioEncoderStateDictConverter`` instance.

    A fresh converter object is constructed on every call.
    """
    converter = WanS2VAudioEncoderStateDictConverter()
    return converter
|
||||
|
||||
Reference in New Issue
Block a user