From 8ef91b36728947e3563af2925a0a6c861a6309a3 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 29 Jul 2025 13:28:42 +0800 Subject: [PATCH] support training for eligen and nexusgen --- README.md | 4 +-- README_zh.md | 4 +-- diffsynth/models/nexus_gen.py | 11 +++--- diffsynth/pipelines/flux_image_new.py | 3 +- diffsynth/trainers/utils.py | 9 +++-- .../flux/model_inference/Nexus-Gen-Editing.py | 11 +++--- .../Nexus-Gen-Editing.py | 36 +++++++++++++++++++ .../full/FLUX.1-NexusGen-Edit.sh | 14 ++++++++ .../full/accelerate_config_zero2offload.yaml | 22 ++++++++++++ .../lora/FLUX.1-NexusGen-Edit.sh | 17 +++++++++ .../model_training/lora/FLUX.1-dev-EliGen.sh | 17 +++++++++ .../validate_full/Nexus-Gen-Editing.py | 28 +++++++++++++++ .../validate_lora/FLUX.1-dev-EliGen.py | 33 +++++++++++++++++ .../validate_lora/Nexus-Gen-Editing.py | 26 ++++++++++++++ 14 files changed, 218 insertions(+), 17 deletions(-) create mode 100644 examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py create mode 100644 examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh create mode 100644 examples/flux/model_training/full/accelerate_config_zero2offload.yaml create mode 100644 examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh create mode 100644 examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh create mode 100644 examples/flux/model_training/validate_full/Nexus-Gen-Editing.py create mode 100644 examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py create mode 100644 examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py diff --git a/README.md b/README.md index 11403a5..f592abb 100644 --- a/README.md +++ b/README.md @@ -96,12 +96,12 @@ image.save("image.jpg") |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| - +|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| diff --git a/README_zh.md b/README_zh.md index 650d2ec..dc1b514 100644 --- a/README_zh.md +++ b/README_zh.md @@ -98,12 +98,12 @@ image.save("image.jpg") |[FLUX.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)| |[FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)| |[FLUX.1-dev-InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](./examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](./examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)| -|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-||| +|[FLUX.1-dev-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](./examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](./examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](./examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)| |[FLUX.1-dev-LoRA-Encoder](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](./examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](./examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-| |[FLUX.1-dev-LoRA-Fusion-Preview](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](./examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-| |[Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](./examples/flux/model_inference/Step1X-Edit.py)|[code](./examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](./examples/flux/model_training/full/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](./examples/flux/model_training/lora/Step1X-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Step1X-Edit.py)| |[FLEX.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](./examples/flux/model_inference/FLEX.2-preview.py)|[code](./examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](./examples/flux/model_training/full/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](./examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](./examples/flux/model_training/validate_lora/FLEX.2-preview.py)| - +|[Nexus-Gen-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](./examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](./examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_full/Nexus-Gen-Editing.py)|[code](./examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh)|[code](./examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py)| ### Wan 系列 diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py index f7a771e..31475c7 100644 --- a/diffsynth/models/nexus_gen.py +++ b/diffsynth/models/nexus_gen.py @@ -14,7 +14,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module): self.model = Qwen2_5_VLForConditionalGeneration(model_config) self.processor = Qwen2_5_VLProcessor.from_pretrained(model_path) - + @staticmethod def state_dict_converter(): return NexusGenAutoregressiveModelStateDictConverter() @@ -34,6 +34,7 @@ class NexusGenAutoregressiveModel(torch.nn.Module): return messages def get_generation_msg(self, instruction): + instruction = "Generate an image according to the following description: {}".format(instruction) messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] return messages @@ -80,9 +81,10 @@ class NexusGenAutoregressiveModel(torch.nn.Module): ) input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) - position_ids, _ = model.get_rope_index(inputs['input_ids'], - inputs['image_grid_thw'], - attention_mask=inputs['attention_mask']) + position_ids, _ = model.get_rope_index( + inputs['input_ids'], + inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) position_ids = position_ids.contiguous() outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True) output_image_embeddings = outputs.image_embeddings[:, :-1, :] @@ -97,4 +99,3 @@ class NexusGenAutoregressiveModelStateDictConverter: def from_civitai(self, state_dict): state_dict = {"model." + key: value for key, value in state_dict.items()} return state_dict - \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 36d7922..8f9ec61 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -767,9 +767,10 @@ class FluxImageUnit_EntityControl(PipelineUnit): if eligen_entity_prompts is None or eligen_entity_masks is None: return inputs_shared, inputs_posi, inputs_nega pipe.load_models_to_device(self.onload_model_names) + eligen_enable_on_negative = inputs_shared.get("eligen_enable_on_negative", False) eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], - inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"]) + inputs_shared["t5_sequence_length"], eligen_enable_on_negative, inputs_shared["cfg_scale"]) inputs_posi.update(eligen_kwargs_posi) if inputs_shared.get("cfg_scale", 1.0) != 1.0: inputs_nega.update(eligen_kwargs_nega) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..07e3664 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -120,8 +120,13 @@ class ImageDataset(torch.utils.data.Dataset): data = self.data[data_id % len(self.data)].copy() for key in self.data_file_keys: if key in data: - path = os.path.join(self.base_path, data[key]) - data[key] = self.load_data(path) + if isinstance(data[key], list): + print(f"Loading multiple files for key '{key}'.") + path = [os.path.join(self.base_path, p) for p in data[key]] + data[key] = [self.load_data(p) for p in path] + else: + path = os.path.join(self.base_path, data[key]) + data[key] = self.load_data(path) if data[key] is None: warnings.warn(f"cannot load file {data[key]}.") return None diff --git a/examples/flux/model_inference/Nexus-Gen-Editing.py b/examples/flux/model_inference/Nexus-Gen-Editing.py index 603ac33..f24f0c0 100644 --- a/examples/flux/model_inference/Nexus-Gen-Editing.py +++ b/examples/flux/model_inference/Nexus-Gen-Editing.py @@ -2,7 +2,7 @@ import importlib import torch from PIL import Image from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig -from modelscope import snapshot_download +from modelscope import snapshot_download, dataset_snapshot_download if importlib.util.find_spec("transformers") is None: raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") @@ -23,12 +23,13 @@ pipe = FluxImagePipeline.from_pretrained( ], ) -prompt = "给猫加一副太阳镜" -ref_image = Image.open("cat.png").convert("RGB") +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." image = pipe( prompt=prompt, negative_prompt="", - seed=0, cfg_scale=1.0, num_inference_steps=50, + seed=42, cfg_scale=2.0, num_inference_steps=50, nexus_gen_reference_image=ref_image, height=512, width=512, ) -image.save("cat_glasses.jpg") +image.save("cat_crown.jpg") diff --git a/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py new file mode 100644 index 0000000..70a543f --- /dev/null +++ b/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py @@ -0,0 +1,36 @@ +import importlib +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download, dataset_snapshot_download + +if importlib.util.find_spec("transformers") is None: + raise ImportError("You are using Nexus-GenV2. It depends on transformers, which is not installed. Please install it with `pip install transformers==4.49.0`.") +else: + import transformers + assert transformers.__version__ == "4.49.0", "Nexus-GenV2 requires transformers==4.49.0, please install it with `pip install transformers==4.49.0`." + +snapshot_download("DiffSynth-Studio/Nexus-GenV2", local_dir="models/DiffSynth-Studio/Nexus-GenV2") +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/nexusgen/cat.jpg") +ref_image = Image.open("data/examples/nexusgen/cat.jpg").convert("RGB") +prompt = "Add a crown." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("cat_crown.jpg") diff --git a/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh new file mode 100644 index 0000000..ab1c324 --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-NexusGen-Edit.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 262144 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_full" \ + --trainable_models "dit" \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing_offload diff --git a/examples/flux/model_training/full/accelerate_config_zero2offload.yaml b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml new file mode 100644 index 0000000..8a75f3d --- /dev/null +++ b/examples/flux/model_training/full/accelerate_config_zero2offload.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: 'cpu' + offload_param_device: 'cpu' + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh b/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh new file mode 100644 index 0000000..3e6eac1 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-NexusGen-Edit.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_nexusgen_edit.csv \ + --data_file_keys "image,nexus_gen_reference_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "DiffSynth-Studio/Nexus-GenV2:model*.safetensors,DiffSynth-Studio/Nexus-GenV2:edit_decoder.bin,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-NexusGen-Edit_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "nexus_gen_reference_image" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh new file mode 100644 index 0000000..10a18e0 --- /dev/null +++ b/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh @@ -0,0 +1,17 @@ +accelerate launch examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_eligen.json \ + --data_file_keys "image,eligen_entity_masks" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-dev-EliGen_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --extra_inputs "eligen_entity_masks,eligen_entity_prompts" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py new file mode 100644 index 0000000..5f7a2d2 --- /dev/null +++ b/examples/flux/model_training/validate_full/Nexus-Gen-Editing.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-NexusGen-Edit_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=2.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_full.jpg") diff --git a/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py new file mode 100644 index 0000000..7df3db2 --- /dev/null +++ b/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +pipe.load_lora(pipe.dit, "models/train/FLUX.1-dev-EliGen_lora/epoch-4.safetensors", alpha=1) + +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +masks = [Image.open(f"data/example_image_dataset/eligen/{i}.png").convert('RGB') for i in range(len(entity_prompts))] +# generate image +image = pipe( + prompt=global_prompt, + cfg_scale=1.0, + num_inference_steps=50, + embedded_guidance=3.5, + seed=42, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, +) +image.save(f"EliGen_lora.png") diff --git a/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py b/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py new file mode 100644 index 0000000..21c376f --- /dev/null +++ b/examples/flux/model_training/validate_lora/Nexus-Gen-Editing.py @@ -0,0 +1,26 @@ +import torch +from PIL import Image +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +pipe.load_lora(pipe.dit, "models/train/FLUX.1-NexusGen-Edit_lora/epoch-4.safetensors", alpha=1) + +ref_image = Image.open("data/example_image_dataset/nexus_gen/image_1.png").convert("RGB") +prompt = "Add a pair of sunglasses." +image = pipe( + prompt=prompt, negative_prompt="", + seed=42, cfg_scale=1.0, num_inference_steps=50, + nexus_gen_reference_image=ref_image, + height=512, width=512, +) +image.save("NexusGen-Edit_lora.jpg")