mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
wans2v lowvram
This commit is contained in:
@@ -99,19 +99,17 @@ class WanS2VAudioEncoder(torch.nn.Module):
         self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
         self.video_rate = 30
 
-    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32):
-        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype)
+    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
+        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
 
         # retrieve logits & take argmax
-        res = self.model(input_values.to(self.model.device), output_hidden_states=True)
+        res = self.model(input_values, output_hidden_states=True)
         if return_all_layers:
             feat = torch.cat(res.hidden_states)
         else:
             feat = res.hidden_states[-1]
         feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
-
-        z = feat.to(dtype)
-        return z
+        return feat
 
     def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
         num_layers, audio_frame_num, audio_dim = audio_embed.shape
Reference in New Issue
Block a user