mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
low_vram
This commit is contained in:
@@ -295,6 +295,43 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
# ACE-Step module maps
|
||||
"diffsynth.models.ace_step_dit.AceStepDiTModel": {
|
||||
"diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_vae.AceStepVAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
def QwenImageTextEncoder_Module_Map_Updater():
|
||||
|
||||
@@ -522,7 +522,7 @@ class AceStepDiTLayer(nn.Module):
|
||||
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
|
||||
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb
|
||||
self.scale_shift_table.to(temb.device) + temb
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# Step 1: Self-attention with adaptive layer norm (AdaLN)
|
||||
@@ -889,7 +889,7 @@ class AceStepDiTModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
# Extract scale-shift parameters for adaptive output normalization
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
|
||||
@@ -594,7 +594,7 @@ class AudioTokenDetokenizer(nn.Module):
|
||||
x = self.embed_tokens(x)
|
||||
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
|
||||
special_tokens = self.special_tokens.expand(B, T, -1, -1)
|
||||
x = x + special_tokens
|
||||
x = x + special_tokens.to(x.device)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
|
||||
cache_position = torch.arange(0, x.shape[1], device=x.device)
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
@@ -240,3 +240,9 @@ class AceStepVAE(nn.Module):
|
||||
"""Full round-trip: encode → decode."""
|
||||
z = self.encode(sample)
|
||||
return self.decoder(z)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization from all conv layers (for export/inference)."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
|
||||
remove_weight_norm(module)
|
||||
|
||||
@@ -69,6 +69,7 @@ class AceStepPipeline(BasePipeline):
|
||||
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
||||
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
||||
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
||||
pipe.vae.remove_weight_norm()
|
||||
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
||||
|
||||
if text_tokenizer_config is not None:
|
||||
@@ -372,8 +373,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
)
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
@@ -468,10 +470,15 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
|
||||
return {"lm_hints": None}
|
||||
|
||||
pipe.load_models_to_device(["tokenizer_model"])
|
||||
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
|
||||
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
|
||||
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
|
||||
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
detokenizer = pipe.tokenizer_model.detokenizer
|
||||
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
quantized = quantizer.project_out(quantized)
|
||||
|
||||
lm_hints = detokenizer(quantized).to(pipe.device)
|
||||
return {"lm_hints": lm_hints}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user