From 0336551544f372fcf94ed7c6bcff936c8ce754de Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 21 Nov 2025 20:26:49 +0800 Subject: [PATCH] bug fix --- diffsynth/diffusion/training_module.py | 2 +- diffsynth/pipelines/flux_image.py | 40 +++++++++++++++---- .../full/FLUX.1-dev-IP-Adapter.sh | 2 +- .../flux/model_training/full/Step1X-Edit.sh | 2 +- .../lora/FLUX.1-dev-IP-Adapter.sh | 2 +- .../flux/model_training/lora/Step1X-Edit.sh | 2 +- .../validate_full/FLUX.1-dev-IP-Adapter.py | 2 +- .../validate_lora/FLEX.2-preview.py | 2 +- .../validate_lora/FLUX.1-dev-IP-Adapter.py | 2 +- .../validate_lora/Step1X-Edit.py | 2 +- .../Qwen-Image-Distill-DMD2.py | 3 +- .../Qwen-Image-Distill-DMD2.py | 3 +- examples/{test/run.py => unit_test.py} | 2 +- 13 files changed, 46 insertions(+), 20 deletions(-) rename examples/{test/run.py => unit_test.py} (100%) diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py index 983984b..034336d 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -202,7 +202,7 @@ class DiffusionTrainingModule(torch.nn.Module): if name not in controlnet_inputs: controlnet_inputs[name] = {} controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input] - break + break else: inputs_shared[extra_input] = data[extra_input] for name, params in controlnet_inputs.items(): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index d5fc30e..2ef2617 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -145,7 +145,8 @@ class FluxImagePipeline(BasePipeline): value_controllers = model_pool.fetch_model("flux_value_controller") if value_controllers is not None: pipe.value_controller = MultiValueEncoder(value_controllers) - pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled + if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"): + pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled controlnets = model_pool.fetch_model("flux_controlnet") if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets) pipe.ipadapter = model_pool.fetch_model("flux_ipadapter") @@ -295,7 +296,7 @@ class FluxImagePipeline(BasePipeline): class FluxImageUnit_ShapeChecker(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width")) + super().__init__(input_params=("height", "width"), output_params=("height", "width")) def process(self, pipe: FluxImagePipeline, height, width): height, width = pipe.check_resize_height_width(height, width) @@ -305,7 +306,7 @@ class FluxImageUnit_ShapeChecker(PipelineUnit): class FluxImageUnit_NoiseInitializer(PipelineUnit): def __init__(self): - super().__init__(input_params=("height", "width", "seed", "rand_device")) + super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",)) def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device): noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device) @@ -317,6 +318,7 @@ class FluxImageUnit_InputImageEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"), + output_params=("latents", "input_latents"), onload_model_names=("vae_encoder",) ) @@ -341,6 +343,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit): input_params_posi={"prompt": "prompt", "positive": "positive"}, input_params_nega={"prompt": "negative_prompt", "positive": "positive"}, input_params=("t5_sequence_length",), + output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"), onload_model_names=("text_encoder_1", "text_encoder_2") ) @@ -396,7 +399,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit): class FluxImageUnit_ImageIDs(PipelineUnit): def __init__(self): - super().__init__(input_params=("latents",)) + super().__init__(input_params=("latents",), output_params=("image_ids",)) def process(self, pipe: FluxImagePipeline, latents): latent_image_ids = pipe.dit.prepare_image_ids(latents) @@ -406,7 +409,7 @@ class FluxImageUnit_ImageIDs(PipelineUnit): class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): def __init__(self): - super().__init__(input_params=("embedded_guidance", "latents")) + super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",)) def process(self, pipe: FluxImagePipeline, embedded_guidance, latents): guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) @@ -416,7 +419,11 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): class FluxImageUnit_Kontext(PipelineUnit): def __init__(self): - super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride")) + super().__init__( + input_params=("kontext_images", "tiled", "tile_size", "tile_stride"), + output_params=("kontext_latents", "kontext_image_ids"), + onload_model_names=("vae_encoder",) + ) def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride): if kontext_images is None: @@ -444,6 +451,7 @@ class FluxImageUnit_ControlNet(PipelineUnit): def __init__(self): super().__init__( input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"), + output_params=("controlnet_conditionings",), onload_model_names=("vae_encoder",) ) @@ -486,6 +494,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit): def __init__(self): super().__init__( take_over=True, + input_params=("ipadapter_images", "ipadapter_scale"), + output_params=("ipadapter_kwargs_list",), onload_model_names=("ipadapter_image_encoder", "ipadapter") ) @@ -513,6 +523,8 @@ class FluxImageUnit_EntityControl(PipelineUnit): def __init__(self): super().__init__( take_over=True, + input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"), + output_params=("entity_prompt_emb", "entity_masks"), onload_model_names=("text_encoder_1", "text_encoder_2") ) @@ -603,6 +615,8 @@ class FluxImageUnit_NexusGen(PipelineUnit): def __init__(self): super().__init__( take_over=True, + input_params=("nexus_gen_reference_image", "prompt", "latents"), + output_params=("prompt_emb", "text_ids"), onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"), ) @@ -655,7 +669,12 @@ class FluxImageUnit_NexusGen(PipelineUnit): class FluxImageUnit_Step1x(PipelineUnit): def __init__(self): - super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder")) + super().__init__( + take_over=True, + input_params=("step1x_reference_image", "prompt", "negative_prompt"), + output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"), + onload_model_names=("qwenvl","vae_encoder") + ) def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): image = inputs_shared.get("step1x_reference_image",None) @@ -678,7 +697,7 @@ class FluxImageUnit_Step1x(PipelineUnit): class FluxImageUnit_TeaCache(PipelineUnit): def __init__(self): - super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh")) + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",)) def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): if tea_cache_l1_thresh is None: @@ -690,6 +709,7 @@ class FluxImageUnit_Flex(PipelineUnit): def __init__(self): super().__init__( input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), + output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"), onload_model_names=("vae_encoder",) ) @@ -728,6 +748,7 @@ class FluxImageUnit_InfiniteYou(PipelineUnit): def __init__(self): super().__init__( input_params=("infinityou_id_image", "infinityou_guidance"), + output_params=("id_emb", "infinityou_guidance"), onload_model_names=("infinityou_processor",) ) @@ -747,6 +768,7 @@ class FluxImageUnit_ValueControl(PipelineUnit): input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"}, input_params=("value_controller_inputs",), + output_params=("prompt_emb", "text_ids"), onload_model_names=("value_controller",) ) @@ -825,6 +847,8 @@ class FluxImageUnit_LoRAEncode(PipelineUnit): def __init__(self): super().__init__( take_over=True, + input_params=("lora_encoder_inputs", "lora_encoder_scale"), + output_params=("prompt_emb", "text_ids"), onload_model_names=("lora_encoder",) ) diff --git a/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh index f19f1e7..6db5e79 100644 --- a/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh +++ b/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh @@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \ --data_file_keys "image,ipadapter_images" \ --max_pixels 1048576 \ --dataset_repeat 100 \ - --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \ --learning_rate 1e-5 \ --num_epochs 1 \ --remove_prefix_in_ckpt "pipe.ipadapter." \ diff --git a/examples/flux/model_training/full/Step1X-Edit.sh b/examples/flux/model_training/full/Step1X-Edit.sh index 98c45ce..03ddfda 100644 --- a/examples/flux/model_training/full/Step1X-Edit.sh +++ b/examples/flux/model_training/full/Step1X-Edit.sh @@ -4,7 +4,7 @@ accelerate launch --config_file examples/flux/model_training/full/accelerate_con --data_file_keys "image,step1x_reference_image" \ --max_pixels 1048576 \ --dataset_repeat 400 \ - --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ + --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ --learning_rate 1e-5 \ --num_epochs 1 \ --remove_prefix_in_ckpt "pipe.dit." \ diff --git a/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh b/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh index 98617d0..e110075 100644 --- a/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh +++ b/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh @@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \ --data_file_keys "image,ipadapter_images" \ --max_pixels 1048576 \ --dataset_repeat 50 \ - --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:" \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors,InstantX/FLUX.1-dev-IP-Adapter:ip-adapter.bin,google/siglip-so400m-patch14-384:model.safetensors" \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ diff --git a/examples/flux/model_training/lora/Step1X-Edit.sh b/examples/flux/model_training/lora/Step1X-Edit.sh index 01ac260..a7f1d8f 100644 --- a/examples/flux/model_training/lora/Step1X-Edit.sh +++ b/examples/flux/model_training/lora/Step1X-Edit.sh @@ -4,7 +4,7 @@ accelerate launch examples/flux/model_training/train.py \ --data_file_keys "image,step1x_reference_image" \ --max_pixels 1048576 \ --dataset_repeat 50 \ - --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ + --model_id_with_origin_paths "Qwen/Qwen2.5-VL-7B-Instruct:model-*.safetensors,stepfun-ai/Step1X-Edit:step1x-edit-i1258.safetensors,stepfun-ai/Step1X-Edit:vae.safetensors" \ --learning_rate 1e-4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.dit." \ diff --git a/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py index 10ed877..7c15ded 100644 --- a/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py +++ b/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py @@ -13,7 +13,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), - ModelConfig(model_id="google/siglip-so400m-patch14-384"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), ], ) state_dict = load_state_dict("models/train/FLUX.1-dev-IP-Adapter_full/epoch-0.safetensors") diff --git a/examples/flux/model_training/validate_lora/FLEX.2-preview.py b/examples/flux/model_training/validate_lora/FLEX.2-preview.py index a905918..6a6a60d 100644 --- a/examples/flux/model_training/validate_lora/FLEX.2-preview.py +++ b/examples/flux/model_training/validate_lora/FLEX.2-preview.py @@ -6,7 +6,7 @@ pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py index afe182f..31c295b 100644 --- a/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py @@ -12,7 +12,7 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"), ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), - ModelConfig(model_id="google/siglip-so400m-patch14-384"), + ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors"), ], ) pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-IP-Adapter_lora/epoch-4.safetensors", alpha=1) diff --git a/examples/flux/model_training/validate_lora/Step1X-Edit.py b/examples/flux/model_training/validate_lora/Step1X-Edit.py index 6b50d81..e89ff98 100644 --- a/examples/flux/model_training/validate_lora/Step1X-Edit.py +++ b/examples/flux/model_training/validate_lora/Step1X-Edit.py @@ -7,7 +7,7 @@ pipe = FluxImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct"), + ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern="model-*.safetensors"), ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), ], diff --git a/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py b/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py index ab7ecda..007538f 100644 --- a/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py +++ b/examples/qwen_image/model_inference/Qwen-Image-Distill-DMD2.py @@ -1,4 +1,5 @@ -from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.core import load_state_dict from modelscope import snapshot_download import torch, math diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py index 2cebc24..6b95667 100644 --- a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-DMD2.py @@ -1,4 +1,5 @@ -from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.core import load_state_dict from modelscope import snapshot_download import torch, math diff --git a/examples/test/run.py b/examples/unit_test.py similarity index 100% rename from examples/test/run.py rename to examples/unit_test.py index 93d4488..364af47 100644 --- a/examples/test/run.py +++ b/examples/unit_test.py @@ -110,5 +110,5 @@ def test_flux(): if __name__ == "__main__": test_qwen_image() - test_wan() test_flux() + test_wan()