mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
wans2v lowvram
This commit is contained in:
@@ -99,19 +99,17 @@ class WanS2VAudioEncoder(torch.nn.Module):
|
|||||||
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
|
self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
|
||||||
self.video_rate = 30
|
self.video_rate = 30
|
||||||
|
|
||||||
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32):
|
def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
|
||||||
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(self.model.dtype)
|
input_values = processor(input_audio, sampling_rate=sample_rate, return_tensors="pt").input_values.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
# retrieve logits & take argmax
|
# retrieve logits & take argmax
|
||||||
res = self.model(input_values.to(self.model.device), output_hidden_states=True)
|
res = self.model(input_values, output_hidden_states=True)
|
||||||
if return_all_layers:
|
if return_all_layers:
|
||||||
feat = torch.cat(res.hidden_states)
|
feat = torch.cat(res.hidden_states)
|
||||||
else:
|
else:
|
||||||
feat = res.hidden_states[-1]
|
feat = res.hidden_states[-1]
|
||||||
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
|
feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
|
||||||
|
return feat
|
||||||
z = feat.to(dtype)
|
|
||||||
return z
|
|
||||||
|
|
||||||
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
|
def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
|
||||||
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
num_layers, audio_frame_num, audio_dim = audio_embed.shape
|
||||||
|
|||||||
@@ -50,9 +50,9 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.units = [
|
self.units = [
|
||||||
WanVideoUnit_ShapeChecker(),
|
WanVideoUnit_ShapeChecker(),
|
||||||
WanVideoUnit_NoiseInitializer(),
|
WanVideoUnit_NoiseInitializer(),
|
||||||
|
WanVideoUnit_PromptEmbedder(),
|
||||||
WanVideoUnit_S2V(),
|
WanVideoUnit_S2V(),
|
||||||
WanVideoUnit_InputVideoEmbedder(),
|
WanVideoUnit_InputVideoEmbedder(),
|
||||||
WanVideoUnit_PromptEmbedder(),
|
|
||||||
WanVideoUnit_ImageEmbedderVAE(),
|
WanVideoUnit_ImageEmbedderVAE(),
|
||||||
WanVideoUnit_ImageEmbedderCLIP(),
|
WanVideoUnit_ImageEmbedderCLIP(),
|
||||||
WanVideoUnit_ImageEmbedderFused(),
|
WanVideoUnit_ImageEmbedderFused(),
|
||||||
@@ -266,13 +266,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
module_map = {
|
module_map = {
|
||||||
torch.nn.Linear: AutoWrappedLinear,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
|
torch.nn.Conv1d: AutoWrappedModule,
|
||||||
},
|
},
|
||||||
module_config = dict(
|
module_config = dict(
|
||||||
offload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
offload_device="cpu",
|
offload_device="cpu",
|
||||||
onload_dtype=dtype,
|
onload_dtype=dtype,
|
||||||
onload_device="cpu",
|
onload_device="cpu",
|
||||||
computation_dtype=dtype,
|
computation_dtype=self.torch_dtype,
|
||||||
computation_device=self.device,
|
computation_device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -905,14 +906,14 @@ class WanVideoUnit_S2V(PipelineUnit):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
take_over=True,
|
take_over=True,
|
||||||
onload_model_names=("audio_encoder", "vae", )
|
onload_model_names=("audio_encoder", "vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
|
def process_audio(self, pipe: WanVideoPipeline, input_audio, audio_sample_rate, num_frames):
|
||||||
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
|
if input_audio is None or pipe.audio_encoder is None or pipe.audio_processor is None:
|
||||||
return {}
|
return {}
|
||||||
pipe.load_models_to_device(["audio_encoder"])
|
pipe.load_models_to_device(["audio_encoder"])
|
||||||
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True)
|
z = pipe.audio_encoder.extract_audio_feat(input_audio, audio_sample_rate, pipe.audio_processor, return_all_layers=True, dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
|
audio_embed_bucket, num_repeat = pipe.audio_encoder.get_audio_embed_bucket_fps(
|
||||||
z, fps=16, batch_frames=num_frames - 1, m=0
|
z, fps=16, batch_frames=num_frames - 1, m=0
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user