mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
vram
This commit is contained in:
@@ -155,4 +155,24 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.flux2_vae.Flux2VAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_dit.ZImageDiT": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -62,6 +62,8 @@ class DiskMap:
|
||||
param = self.files[file_id].get_tensor(name)
|
||||
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
|
||||
param = param.to(self.torch_dtype)
|
||||
if param.device == "cpu":
|
||||
param = param.clone()
|
||||
if isinstance(param, torch.Tensor):
|
||||
self.num_params += param.numel()
|
||||
if self.num_params > self.buffer_size:
|
||||
|
||||
@@ -51,7 +51,7 @@ class TimestepEmbedder(nn.Module):
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
|
||||
t_emb = self.mlp(t_freq.to(torch.bfloat16))
|
||||
return t_emb
|
||||
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ class ZImagePipeline(BasePipeline):
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
image = self.vae_decoder(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
Reference in New Issue
Block a user