diff --git a/diffsynth/models/wav2vec.py b/diffsynth/models/wav2vec.py
index f07e99e..e17bd96 100644
--- a/diffsynth/models/wav2vec.py
+++ b/diffsynth/models/wav2vec.py
@@ -99,19 +99,17 @@ class WanS2VAudioEncoder(torch.nn.Module):
         self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
         self.video_rate = 30
 
-    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32):
-        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype)
+    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
+        input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
 
         # retrieve logits & take argmax
-        res = self.model(input_values.to(self.model.device), output_hidden_states=True)
+        res = self.model(input_values, output_hidden_states=True)
         if return_all_layers:
             feat = torch.cat(res.hidden_states)
         else:
             feat = res.hidden_states[-1]
         feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
-
-        z = feat.to(dtype)
-        return z
+        return feat
 
     def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
         num_layers, audio_frame_num, audio_dim = audio_embed.shape
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index ff3c4bd..16df7f4 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -50,9 +50,9 @@ class WanVideoPipeline(BasePipeline):
         self.units = [
             WanVideoUnit_ShapeChecker(),
             WanVideoUnit_NoiseInitializer(),
+            WanVideoUnit_PromptEmbedder(),
             WanVideoUnit_S2V(),
             WanVideoUnit_InputVideoEmbedder(),
-            WanVideoUnit_PromptEmbedder(),
             WanVideoUnit_ImageEmbedderVAE(),
             WanVideoUnit_ImageEmbedderCLIP(),
             WanVideoUnit_ImageEmbedderFused(),
@@ -266,13 +266,14 @@ class WanVideoPipeline(BasePipeline):
                 module_map = {
                     torch.nn.Linear: AutoWrappedLinear,
                     torch.nn.LayerNorm: AutoWrappedModule,
+                    torch.nn.Conv1d: AutoWrappedModule,
                 },
                 module_config = dict(
                     offload_dtype=dtype,
                     offload_device="cpu",
                     onload_dtype=dtype,
                     onload_device="cpu",
-                    computation_dtype=dtype,
+                    computation_dtype=self.torch_dtype,
                     computation_device=self.device,
                 ),
             )
@@ -905,14 +906,14 @@ class WanVideoUnit_S2V(PipelineUnit):
     def __init__(self):
         super().__init__(
             take_over=True,
-            onload_model_names=("audio_encoder", "vae", )
+            onload_model_names=("audio_encoder", "vae",)
         )
 
     def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
        if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
            return {}
        pipe.load_models_to_device(["audio_encoder"])
-        z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True)
+        z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device)
         audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
             z, fps=16, batch_frames=num_frames - 1, m=0
         )