diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index b736625..6e92f3b 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -37,6 +37,7 @@ class Flux2ImagePipeline(BasePipeline):
             Flux2Unit_Qwen3PromptEmbedder(),
             Flux2Unit_NoiseInitializer(),
             Flux2Unit_InputImageEmbedder(),
+            Flux2Unit_EditImageEmbedder(),
             Flux2Unit_ImageIDs(),
         ]
         self.model_fn = model_fn_flux2
@@ -79,6 +80,9 @@ class Flux2ImagePipeline(BasePipeline):
         # Image
         input_image: Image.Image = None,
         denoising_strength: float = 1.0,
+        # Edit
+        edit_image: Union[Image.Image, List[Image.Image]] = None,
+        edit_image_auto_resize: bool = True,
         # Shape
         height: int = 1024,
         width: int = 1024,
@@ -102,6 +106,7 @@ class Flux2ImagePipeline(BasePipeline):
         inputs_shared = {
             "cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
             "input_image": input_image, "denoising_strength": denoising_strength,
+            "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
             "height": height, "width": width,
             "seed": seed, "rand_device": rand_device,
             "num_inference_steps": num_inference_steps,
@@ -456,6 +461,64 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
         return {"latents": latents, "input_latents": input_latents}
 
 
+class Flux2Unit_EditImageEmbedder(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("edit_image", "edit_image_auto_resize"),
+            output_params=("edit_latents", "edit_image_ids"),
+            onload_model_names=("vae",)
+        )
+
+    def calculate_dimensions(self, target_area, ratio):
+        import math
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
+    def edit_image_auto_resize(self, edit_image):
+        calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
+        return edit_image.resize((calculated_width, calculated_height))
+
+    def process_image_ids(self, image_latents, scale=10):
+        t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+        t_coords = [t.view(-1) for t in t_coords]
+
+        image_latent_ids = []
+        for x, t in zip(image_latents, t_coords):
+            x = x.squeeze(0)
+            _, height, width = x.shape
+
+            x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+            image_latent_ids.append(x_ids)
+
+        image_latent_ids = torch.cat(image_latent_ids, dim=0)
+        image_latent_ids = image_latent_ids.unsqueeze(0)
+
+        return image_latent_ids
+
+    def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
+        if edit_image is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        if isinstance(edit_image, Image.Image):
+            edit_image = [edit_image]
+        resized_edit_image, edit_latents = [], []
+        for image in edit_image:
+            # Preprocess
+            if edit_image_auto_resize is None or edit_image_auto_resize:
+                image = self.edit_image_auto_resize(image)
+            resized_edit_image.append(image)
+            # Encode
+            image = pipe.preprocess_image(image)
+            latents = pipe.vae.encode(image)
+            edit_latents.append(latents)
+        edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
+        edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
+        return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
+
+
 class Flux2Unit_ImageIDs(PipelineUnit):
     def __init__(self):
         super().__init__(
@@ -490,10 +553,16 @@ def model_fn_flux2(
     prompt_embeds=None,
     text_ids=None,
     image_ids=None,
+    edit_latents=None,
+    edit_image_ids=None,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
 ):
+    image_seq_len = latents.shape[1]
+    if edit_latents is not None:
+        latents = torch.concat([latents, edit_latents], dim=1)
+        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
     embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
     model_output = dit(
         hidden_states=latents,
@@ -505,4 +574,5 @@ def model_fn_flux2(
         use_gradient_checkpointing=use_gradient_checkpointing,
         use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
     )
+    model_output = model_output[:, :image_seq_len]
     return model_output
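Note on the pipeline changes above: each edit image is snapped to roughly one megapixel with both sides rounded to multiples of 32 (`calculate_dimensions`), VAE-encoded, flattened to tokens, and appended after the noisy latents; `process_image_ids` gives the i-th edit image the t-coordinate `scale * (i + 1)` (10, 20, ...), so each reference gets a distinct position offset, and `model_fn_flux2` trims the DiT output back to the first `image_seq_len` tokens. A minimal sketch of the arithmetic (the channel count and token counts below are illustrative, not taken from this diff):

```python
import math
import torch

# Resize rule from Flux2Unit_EditImageEmbedder.calculate_dimensions:
# ~1 MP target area, aspect ratio preserved, sides rounded to multiples of 32.
def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32

print(calculate_dimensions(1024 * 1024, 512 / 768))  # -> (832, 1248)

# Sequence layout in model_fn_flux2: edit tokens are appended after the
# noisy latents, and only the first image_seq_len outputs are kept.
latents = torch.randn(1, 4096, 64)       # hypothetical flattened target latents
edit_latents = torch.randn(1, 1014, 64)  # hypothetical edit-image tokens
image_seq_len = latents.shape[1]
joint = torch.concat([latents, edit_latents], dim=1)  # (1, 5110, 64)
model_output = joint[:, :image_seq_len]               # (1, 4096, 64)
```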
diff --git a/examples/flux2/model_inference/FLUX.2-klein-4B.py b/examples/flux2/model_inference/FLUX.2-klein-4B.py
index fbfe33d..2175901 100644
--- a/examples/flux2/model_inference/FLUX.2-klein-4B.py
+++ b/examples/flux2/model_inference/FLUX.2-klein-4B.py
@@ -15,3 +15,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
 image.save("image_FLUX.2-klein-4B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
+image.save("image_edit_FLUX.2-klein-4B.jpg")
diff --git a/examples/flux2/model_inference/FLUX.2-klein-9B.py b/examples/flux2/model_inference/FLUX.2-klein-9B.py
index 2abf0e7..b20fc2c 100644
--- a/examples/flux2/model_inference/FLUX.2-klein-9B.py
+++ b/examples/flux2/model_inference/FLUX.2-klein-9B.py
@@ -15,3 +15,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
 image.save("image_FLUX.2-klein-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
+image.save("image_edit_FLUX.2-klein-9B.jpg")
diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-4B.py b/examples/flux2/model_inference/FLUX.2-klein-base-4B.py
index 8ce4521..064e035 100644
--- a/examples/flux2/model_inference/FLUX.2-klein-base-4B.py
+++ b/examples/flux2/model_inference/FLUX.2-klein-base-4B.py
@@ -15,3 +15,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-4B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-4B.jpg")
diff --git a/examples/flux2/model_inference/FLUX.2-klein-base-9B.py b/examples/flux2/model_inference/FLUX.2-klein-base-9B.py
index aa7193f..e2e6065 100644
--- a/examples/flux2/model_inference/FLUX.2-klein-base-9B.py
+++ b/examples/flux2/model_inference/FLUX.2-klein-base-9B.py
@@ -15,3 +15,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-9B.jpg")
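The four examples above differ only in sampling settings: the distilled klein checkpoints sample in 4 steps with default guidance, while the klein-base checkpoints use 50 steps with cfg_scale=4. If you drive several variants from one script, a small mapping keeps the calls consistent (illustrative convenience code, not part of this PR):

```python
# Sampling settings taken from the example scripts above.
SAMPLING_KWARGS = {
    "FLUX.2-klein-4B": dict(num_inference_steps=4),
    "FLUX.2-klein-9B": dict(num_inference_steps=4),
    "FLUX.2-klein-base-4B": dict(num_inference_steps=50, cfg_scale=4),
    "FLUX.2-klein-base-9B": dict(num_inference_steps=50, cfg_scale=4),
}

image = pipe(prompt, seed=0, rand_device="cuda",
             **SAMPLING_KWARGS["FLUX.2-klein-4B"])
```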
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-9B.jpg")
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
index 019f58e..dbdc8e4 100644
--- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py
@@ -25,3 +25,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
 image.save("image_FLUX.2-klein-4B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
+image.save("image_edit_FLUX.2-klein-4B.jpg")
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
index b629c94..dc7b9a7 100644
--- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py
@@ -25,3 +25,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
 image.save("image_FLUX.2-klein-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
+image.save("image_edit_FLUX.2-klein-9B.jpg")
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py
index 733a006..5a1517f 100644
--- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py
@@ -25,3 +25,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-4B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-4B.jpg")
diff --git a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py
index d5f5f80..e0df8a6 100644
--- a/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py
+++ b/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py
@@ -25,3 +25,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
 prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
 image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-9B.jpg")
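The low-VRAM variants accept the same editing arguments. Since `edit_image` takes either a single PIL image or a list, multi-reference editing only requires passing more images, and `edit_image_auto_resize=False` keeps the references at their original resolution. A hedged usage sketch (file names are hypothetical; `pipe` is constructed as in the scripts above):

```python
from PIL import Image

# Hypothetical reference images; each one is VAE-encoded and appended to the
# token sequence with its own t-coordinate offset.
subject = Image.open("subject.jpg")
style_ref = Image.open("style_ref.jpg")

image = pipe(
    "put the subject into the style of the second image",
    edit_image=[subject, style_ref],  # one image or a list of images
    edit_image_auto_resize=True,      # snap each reference to ~1 MP
    seed=1, rand_device="cuda", num_inference_steps=4,
)
image.save("image_edit_multi.jpg")
```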
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
 image.save("image_FLUX.2-klein-base-9B.jpg")
+
+prompt = "change the color of the clothes to red"
+image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
+image.save("image_edit_FLUX.2-klein-base-9B.jpg")
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh
index 4fa46da..9f9a206 100644
--- a/examples/flux2/model_training/full/FLUX.2-klein-4B.sh
+++ b/examples/flux2/model_training/full/FLUX.2-klein-4B.sh
@@ -11,3 +11,20 @@ accelerate launch examples/flux2/model_training/train.py \
   --output_path "./models/train/FLUX.2-klein-4B_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+#   --learning_rate 1e-5 \
+#   --num_epochs 2 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-4B_full" \
+#   --trainable_models "dit" \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh
index c89e8f0..2cc439b 100644
--- a/examples/flux2/model_training/full/FLUX.2-klein-9B.sh
+++ b/examples/flux2/model_training/full/FLUX.2-klein-9B.sh
@@ -1,4 +1,5 @@
-accelerate launch examples/flux2/model_training/train.py \
+# This script was tested on 8x A100 GPUs.
+accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
   --dataset_base_path data/example_image_dataset \
   --dataset_metadata_path data/example_image_dataset/metadata.csv \
   --max_pixels 1048576 \
@@ -11,3 +12,20 @@ accelerate launch examples/flux2/model_training/train.py \
   --output_path "./models/train/FLUX.2-klein-9B_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+#   --learning_rate 1e-5 \
+#   --num_epochs 2 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-9B_full" \
+#   --trainable_models "dit" \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh b/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh
index 0862391..6590806 100644
--- a/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh
+++ b/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh
@@ -11,3 +11,20 @@ accelerate launch examples/flux2/model_training/train.py \
   --output_path "./models/train/FLUX.2-klein-base-4B_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+#   --learning_rate 1e-5 \
+#   --num_epochs 2 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-base-4B_full" \
+#   --trainable_models "dit" \
+#   --use_gradient_checkpointing
\ No newline at end of file
diff --git a/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh b/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh
index d33a21f..4fb8064 100644
--- a/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh
+++ b/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh
@@ -1,4 +1,5 @@
-accelerate launch examples/flux2/model_training/train.py \
+# This script was tested on 8x A100 GPUs.
+accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
   --dataset_base_path data/example_image_dataset \
   --dataset_metadata_path data/example_image_dataset/metadata.csv \
   --max_pixels 1048576 \
@@ -11,3 +12,20 @@ accelerate launch examples/flux2/model_training/train.py \
   --output_path "./models/train/FLUX.2-klein-base-9B_full" \
   --trainable_models "dit" \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+#   --learning_rate 1e-5 \
+#   --num_epochs 2 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-base-9B_full" \
+#   --trainable_models "dit" \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/full/accelerate_config.yaml b/examples/flux2/model_training/full/accelerate_config.yaml
new file mode 100644
index 0000000..83280f7
--- /dev/null
+++ b/examples/flux2/model_training/full/accelerate_config.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
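The new accelerate config launches 8 local processes with DeepSpeed ZeRO stage 2 (optimizer states and gradients sharded across GPUs, parameters replicated) in bf16 mixed precision, matching the "tested on 8x A100 GPUs" note above. A quick sanity check, assuming PyYAML is available:

```python
import yaml

# Verify the key fields of the new accelerate config.
with open("examples/flux2/model_training/full/accelerate_config.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["distributed_type"] == "DEEPSPEED"
assert cfg["deepspeed_config"]["zero_stage"] == 2
assert cfg["num_processes"] == 8
assert cfg["mixed_precision"] == "bf16"
```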
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
index 8f897cc..0e89205 100644
--- a/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh
@@ -13,3 +13,22 @@ accelerate launch examples/flux2/model_training/train.py \
   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
   --lora_rank 32 \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+#   --learning_rate 1e-4 \
+#   --num_epochs 5 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-4B_lora" \
+#   --lora_base_model "dit" \
+#   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
+#   --lora_rank 32 \
+#   --use_gradient_checkpointing
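The long `--lora_target_modules` value above is mechanical: a fixed set of attention and MLP projections plus `single_transformer_blocks.{i}.attn.to_out` for every single-stream block (20 blocks in the 4B scripts, 24 in the 9B scripts below). A helper like this (not part of the PR) can generate it instead of hand-writing the list:

```python
# Shared projection names from the scripts above.
BASE_MODULES = [
    "to_q", "to_k", "to_v", "to_out.0",
    "add_q_proj", "add_k_proj", "add_v_proj", "to_add_out",
    "linear_in", "linear_out", "to_qkv_mlp_proj",
]

def lora_target_modules(num_single_blocks: int) -> str:
    # 4B models use 20 single transformer blocks, 9B models use 24.
    block_modules = [
        f"single_transformer_blocks.{i}.attn.to_out"
        for i in range(num_single_blocks)
    ]
    return ",".join(BASE_MODULES + block_modules)

print(lora_target_modules(20))  # matches the 4B value above
```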
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
index 258c5fe..26265f1 100644
--- a/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh
@@ -13,3 +13,22 @@ accelerate launch examples/flux2/model_training/train.py \
   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
   --lora_rank 32 \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+#   --learning_rate 1e-4 \
+#   --num_epochs 5 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-9B_lora" \
+#   --lora_base_model "dit" \
+#   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
+#   --lora_rank 32 \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh
index e7f636e..f23e71f 100644
--- a/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh
@@ -13,3 +13,22 @@ accelerate launch examples/flux2/model_training/train.py \
   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
   --lora_rank 32 \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
+#   --learning_rate 1e-4 \
+#   --num_epochs 5 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-base-4B_lora" \
+#   --lora_base_model "dit" \
+#   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
+#   --lora_rank 32 \
+#   --use_gradient_checkpointing
diff --git a/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh b/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh
index d4f65df..d714b0e 100644
--- a/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh
+++ b/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh
@@ -13,3 +13,22 @@ accelerate launch examples/flux2/model_training/train.py \
   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
   --lora_rank 32 \
   --use_gradient_checkpointing
+
+# Edit
+# accelerate launch examples/flux2/model_training/train.py \
+#   --dataset_base_path data/example_image_dataset \
+#   --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
+#   --data_file_keys "image,edit_image" \
+#   --extra_inputs "edit_image" \
+#   --max_pixels 1048576 \
+#   --dataset_repeat 50 \
+#   --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
+#   --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
+#   --learning_rate 1e-4 \
+#   --num_epochs 5 \
+#   --remove_prefix_in_ckpt "pipe.dit." \
+#   --output_path "./models/train/FLUX.2-klein-base-9B_lora" \
+#   --lora_base_model "dit" \
+#   --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
+#   --lora_rank 32 \
+#   --use_gradient_checkpointing
\ No newline at end of file