diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 8e7dde0..219983d 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -155,4 +155,24 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.flux2_vae.Flux2VAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.z_image_dit.ZImageDiT": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index 6f0b6ea..9f80068 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -62,6 +62,8 @@ class DiskMap: param = self.files[file_id].get_tensor(name) if self.torch_dtype is not None and isinstance(param, torch.Tensor): param = param.to(self.torch_dtype) + if param.device == "cpu": + param = param.clone() if isinstance(param, torch.Tensor): self.num_params += param.numel() if self.num_params > self.buffer_size: diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index 661ab2b..40dffaa 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -51,7 +51,7 @@ class TimestepEmbedder(nn.Module): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + t_emb = self.mlp(t_freq.to(torch.bfloat16)) return t_emb diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index b6ad72c..f87254f 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -119,7 +119,7 @@ class ZImagePipeline(BasePipeline): inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) # Decode - self.load_models_to_device(['vae']) + self.load_models_to_device(['vae_decoder']) image = self.vae_decoder(inputs_shared["latents"]) image = self.vae_output_to_image(image) self.load_models_to_device([]) diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-dev.py b/examples/flux2/model_inference_low_vram/FLUX.2-dev.py new file mode 100644 index 0000000..1542e53 --- /dev/null +++ b/examples/flux2/model_inference_low_vram/FLUX.2-dev.py @@ -0,0 +1,28 @@ +from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = Flux2ImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom." +image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50) +image.save("image_FLUX.2-dev.jpg") diff --git a/examples/z_image/model_inference/Z-Image-Turbo.py b/examples/z_image/model_inference/Z-Image-Turbo.py index 1a61f22..afa0c4d 100644 --- a/examples/z_image/model_inference/Z-Image-Turbo.py +++ b/examples/z_image/model_inference/Z-Image-Turbo.py @@ -14,4 +14,4 @@ pipe = ZImagePipeline.from_pretrained( ) prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." image = pipe(prompt=prompt, seed=42, rand_device="cuda") -image.save("image.jpg") +image.save("image_Z-Image-Turbo.jpg") diff --git a/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py b/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py new file mode 100644 index 0000000..6ad8f42 --- /dev/null +++ b/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py @@ -0,0 +1,38 @@ +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +import torch + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +vram_config_cpu_offload = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config_cpu_offload), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights." +image = pipe(prompt=prompt, seed=42, rand_device="cuda") +image.save("image_Z-Image-Turbo.jpg")