support s2v framepack

This commit is contained in:
mi804
2025-09-01 16:48:46 +08:00
parent 1d240994e7
commit 5cee326f92
8 changed files with 220 additions and 56 deletions

View File

@@ -183,6 +183,13 @@ class WanS2VAudioEncoder(torch.nn.Module):
return batch_audio_eb, min_batch_num
def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
audio_feat = self.extract_audio_feat(input_audio, sample_rate, processor, return_all_layers=True, dtype=dtype, device=device)
audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(audio_feat, fps=fps, batch_frames=batch_frames, m=m)
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
return audio_embeds
@staticmethod
def state_dict_converter():
return WanS2VAudioEncoderStateDictConverter()