From 03c8fd5e61dbeba7d1788f9388c4a90ae04e8278 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 29 Jul 2025 18:49:18 +0800 Subject: [PATCH] refine code --- .../Nexus-Gen-Editing.py | 16 +++++++++------- .../Nexus-Gen-Generation.py | 15 ++++++++------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py index 70a543f..313ce3c 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -2,7 +2,8 @@ import importlib import torch from PIL import Image from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download, dataset_snapshot_download +from modelscope import dataset_snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -10,17 +11,18 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) pipe.enable_vram_management() diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py index 053b22b..c865271 100644 --- a/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Generation.py @@ -1,7 +1,7 @@ import importlib import torch from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download + if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -9,17 +9,18 @@ else: import transformers assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." -snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") + pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), - ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), - ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"), ], + nexus_gen_processor_config=ModelConfig("DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor"), ) pipe.enable_vram_management()