mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
low_vram
This commit is contained in:
@@ -69,6 +69,7 @@ class AceStepPipeline(BasePipeline):
|
||||
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
||||
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
||||
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
||||
pipe.vae.remove_weight_norm()
|
||||
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
||||
|
||||
if text_tokenizer_config is not None:
|
||||
@@ -372,8 +373,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
)
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@@ -468,10 +470,15 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
return {"lm_hints": None}
|
||||
|
||||
pipe.load_models_to_device(["tokenizer_model"])
|
||||
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
|
||||
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
|
||||
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
|
||||
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
detokenizer = pipe.tokenizer_model.detokenizer
|
||||
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
quantized = quantizer.project_out(quantized)
|
||||
|
||||
lm_hints = detokenizer(quantized).to(pipe.device)
|
||||
return {"lm_hints": lm_hints}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user