From 157e0be49dcec3aa93df75d895618fecd0b883be Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 30 Jun 2025 11:00:10 +0800 Subject: [PATCH] kontext training --- examples/flux/README.md | 8 +++--- examples/flux/README_zh.md | 4 +-- .../model_training/full/FLUX.1-Kontext-dev.sh | 14 ++++++++++ .../validate_full/FLUX.1-Kontext-dev.py | 26 +++++++++++++++++++ 4 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 examples/flux/model_training/full/FLUX.1-Kontext-dev.sh create mode 100644 examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py diff --git a/examples/flux/README.md b/examples/flux/README.md index a04ecd3..3198fd5 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -1,6 +1,6 @@ # FLUX -[Switch to Chinese](./README_zh.md) +[切换到中文](./README_zh.md) FLUX is a series of image generation models open-sourced by Black-Forest-Labs. @@ -44,9 +44,9 @@ image.save("image.jpg") **Support for the new framework of the FLUX series models is under active development. Stay tuned!** | Model ID | Additional Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training | -|---------|------------------------|-----------|---------------|-------------------------------|---------------|--------------------------------| -| [black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](./model_inference/FLUX.1-dev.py) | [code](./model_training/full/FLUX.1-dev.sh) | [code](./model_training/validate_full/FLUX.1-dev.py) | [code](./model_training/lora/FLUX.1-dev.sh) | [code](./model_training/validate_lora/FLUX.1-dev.py) | -| [black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](./model_inference/FLUX.1-Kontext-dev.py) | | | [code](./model_training/lora/FLUX.1-Kontext-dev.sh) | [code](./model_training/validate_lora/FLUX.1-Kontext-dev.py) | +|-|-|-|-|-|-|-| +|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)| ## Model Inference diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index bb87279..1d9ebff 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -46,7 +46,7 @@ image.save("image.jpg") |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| |[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)| -|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|||[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)| +|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|[code](./model_training/full/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)| ## 模型推理 @@ -252,7 +252,7 @@ video,prompt,kontext_images image1.jpg,"a cat is sleeping",image1_reference.jpg ``` -额外输入若包含视频和图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`。 +额外输入若包含视频和图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`,同时启用 `--extra_inputs "kontext_images"`。 diff --git a/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh b/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh new file mode 100644 index 0000000..de1fa5d --- /dev/null +++ b/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \ + --data_file_keys "image,kontext_images" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 1 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/FLUX.1-Kontext-dev_full" \ + --trainable_models "dit" \ + --extra_inputs "kontext_images" \ + --use_gradient_checkpointing diff --git a/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py b/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py new file mode 100644 index 0000000..af3ee36 --- /dev/null +++ b/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py @@ -0,0 +1,26 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth import load_state_dict +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) +state_dict = load_state_dict("models/train/FLUX.1-Kontext-dev_full/epoch-0.safetensors") +pipe.dit.load_state_dict(state_dict) + +image = pipe( + prompt="Make the dog turn its head around.", + kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)), + height=768, width=768, + seed=0 +) +image.save("image_FLUX.1-Kontext-dev_full.jpg")