From 5e95a85281134c7815abfe562b527e81fa97a51b Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 10 Nov 2025 17:12:55 +0800 Subject: [PATCH] update doc --- diffsynth/pipelines/qwen_image.py | 3 - docs/Developer_Guide/Building_a_Pipeline.md | 1 + docs/Model_Details/Qwen-Image.md | 122 ++++++++++++++++++ docs/QA.md | 19 ++- docs/README.md | 8 +- docs/Training/Differential_LoRA.md | 38 ++++++ docs/Training/Split_Training.md | 95 ++++++++++++++ docs/Training/Supervised_Fine_Tuning.md | 2 +- .../differential_training/Qwen-Image-LoRA.sh | 3 +- 9 files changed, 281 insertions(+), 10 deletions(-) create mode 100644 docs/Training/Differential_LoRA.md diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 1849b08..6e21fc6 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -113,8 +113,6 @@ class QwenImagePipeline(BasePipeline): edit_rope_interpolation: bool = False, # In-context control context_image: Image.Image = None, - # FP8 - enable_fp8_attention: bool = False, # Tile tiled: bool = False, tile_size: int = 128, @@ -138,7 +136,6 @@ class QwenImagePipeline(BasePipeline): "inpaint_mask": inpaint_mask, "inpaint_blur_size": inpaint_blur_size, "inpaint_blur_sigma": inpaint_blur_sigma, "height": height, "width": width, "seed": seed, "rand_device": rand_device, - "enable_fp8_attention": enable_fp8_attention, "num_inference_steps": num_inference_steps, "blockwise_controlnet_inputs": blockwise_controlnet_inputs, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, diff --git a/docs/Developer_Guide/Building_a_Pipeline.md b/docs/Developer_Guide/Building_a_Pipeline.md index f58b3b1..3749118 100644 --- a/docs/Developer_Guide/Building_a_Pipeline.md +++ b/docs/Developer_Guide/Building_a_Pipeline.md @@ -231,6 +231,7 @@ class QwenImageUnit_EntityControl(PipelineUnit): * 缺省兜底:可选功能的 `unit` 输入参数默认为 `None`,而不是 `False` 或其他数值,请对此默认值进行兜底处理。 * 参数触发:部分 Adapter 模型可能是未被加载的,例如 ControlNet,对应的 `unit` 应当以参数输入是否为 `None` 来控制触发,而不是以模型是否被加载来控制触发。例如当用户输入了 `controlnet_image` 但没有加载 ControlNet 模型时,代码应当给出报错,而不是忽略这些输入参数继续执行。 +* 简洁优先:尽可能使用直接模式,仅当功能无法实现时,使用接管模式。 * 显存高效:在 `unit` 中调用模型时,请使用 `pipe.load_models_to_device(self.onload_model_names)` 激活对应的模型,请不要调用 `onload_model_names` 之外的其他模型,`unit` 计算完成后,请不要使用 `pipe.load_models_to_device([])` 手动释放显存。 > Q: 部分参数并未在推理过程中调用,例如 `output_params`,是否仍有必要配置? diff --git a/docs/Model_Details/Qwen-Image.md b/docs/Model_Details/Qwen-Image.md index e69de29..1671e67 100644 --- a/docs/Model_Details/Qwen-Image.md +++ b/docs/Model_Details/Qwen-Image.md @@ -0,0 +1,122 @@ +# Qwen-Image + +![Image](https://github.com/user-attachments/assets/738078d8-8749-4a53-a046-571861541924) + +Qwen-Image 是由阿里巴巴通义实验室开源的图像生成模型。 + +
+ +模型血缘 + +```mermaid +graph LR; + Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit; + Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509; + Qwen/Qwen-Image-->EliGen-Series; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen; + DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2; + EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster; + Qwen/Qwen-Image-->Distill-Series; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full; + Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA; + Qwen/Qwen-Image-->ControlNet-Series; + ControlNet-Series-->Blockwise-ControlNet-Series; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth; + Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint; + ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union; + Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix; +``` + +
+ +## 快速开始 + +通过运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from PIL import Image +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe( + prompt, seed=0, num_inference_steps=40, + # edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit +) +image.save("image.jpg") +``` + +## 模型总览 + +|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| +|-|-|-|-|-|-|-| +|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)| +|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)| +|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)| +|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)| +|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)| +|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-| + +## 模型推理 + +模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/Pipeline_Usage/Model_Inference.md#加载模型)。 + +`QwenImagePipeline` 推理的输入参数包括: + +* `prompt`: 提示词,描述画面中出现的内容。 +* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。 +* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4,当设置为 1 时不再生效。 +* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。 +* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。 +* `inpaint_mask`: 图像局部重绘的遮罩图像。 +* `inpaint_blur_size`: 图像局部重绘的边缘柔化宽度。 +* `inpaint_blur_sigma`: 图像局部重绘的边缘柔化强度。 +* `height`: 图像高度,需保证高度为 16 的倍数。 +* `width`: 图像宽度,需保证宽度为 16 的倍数。 +* `seed`: 随机种子。默认为 `None`,即完全随机。 +* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。 +* `num_inference_steps`: 推理次数,默认值为 30。 +* `exponential_shift_mu`: 在采样时间步时采用的固定参数,留空则根据图像宽高进行采样。 +* `blockwise_controlnet_inputs`: Blockwise ControlNet 模型的输入。 +* `eligen_entity_prompts`: EliGen 分区控制的提示词。 +* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像。 +* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。 +* `edit_image`: 编辑模型的待编辑图像,支持多张图像。 +* `edit_image_auto_resize`: 是否自动缩放待编辑图像。 +* `edit_rope_interpolation`: 是否在低分辨率编辑图像上启用 ROPE 插值。 +* `context_image`: In-Context Control 的输入图像。 +* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。 +* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。 +* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。 +* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。 + +如果显存不足,请开启[显存管理](/docs/Pipeline_Usage/VRAM_management.md)。 + +## 模型训练 + +模型训练脚本位于 `examples/qwen_image/model_training/train.py`,脚本的输入参数包括[基础脚本参数](/docs/Pipeline_Usage/Model_Training.md#脚本参数)以及以下额外参数: + +* `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。 +* `--processor_path`: processor 的路径,适用于图像编辑模型,留空则自动从远程下载。 + +`--task` 参数支持 `sft`([标准监督训练](/docs/Training/Supervised_Fine_Tuning.md))与 `direct_distill`([直接蒸馏](/docs/Training/Direct_Distill.md)),两者都支持[两阶段拆分训练](/docs/Training/Split_Training.md)和[FP8 精度](/docs/Training/FP8_Precision.md)。 + +使用命令 `modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset` 可下载样例数据集。我们为每个模型编写了推荐的训练命令,详见[模型总览](#模型总览)中的表格。详细的训练流程,请参考[模型训练](/docs/Pipeline_Usage/Model_Training.md)。 diff --git a/docs/QA.md b/docs/QA.md index 752e3ea..b1d55df 100644 --- a/docs/QA.md +++ b/docs/QA.md @@ -2,10 +2,27 @@ ## 为什么训练框架不支持 batch size > 1? +* **更大的 batch size 已无法实现显著加速**:由于 flash attention 等加速技术已经充分提高了 GPU 的利用率,因此更大的 batch size 只会带来更大的显存占用,无法带来显著加速。在 Stable Diffusion 1.5 这类小模型上的经验已不再适用于最新的大模型。 +* **更大的 batch size 可以用其他方案实现**:多 GPU 训练和 Gradient Accumulation 都可以在数学意义上等价地实现更大的 batch size。 +* **更大的 batch size 与框架的通用性设计相悖**:我们希望构建通用的训练框架,大量模型无法适配更大的 batch size,例如不同长度的文本编码、不同分辨率的图像等,都是无法合并为更大的 batch 的。 + ## 为什么不删除某些模型中的冗余参数? +在部分模型中,模型存在冗余参数,例如 Qwen-Image 的 DiT 模型最后一层的文本部分,这部分参数不会参与任何计算,这是模型开发者留下的小 bug。直接将其设置为可训练时还会在多 GPU 训练中出现报错。 + +为了与开源社区中其他模型保持兼容性,我们决定保留这些参数。这些冗余参数在多 GPU 训练中可以通过 `--find_unused_parameters` 参数避免报错。 + ## 为什么 FP8 量化没有任何加速效果? +原生 FP8 计算需要依赖 Hopper 架构的 GPU,同时在计算精度上有较大误差,目前仍然是不成熟的技术,因此本项目不支持原生 FP8 计算。 + +显存管理中的 FP8 计算是指将模型参数以 FP8 精度存储在内存或显存中,在需要计算时临时转换为其他精度,因此仅能减少显存占用,没有加速效果。 + ## 为什么训练框架不支持原生 FP8 精度训练? -即使硬件条件允许,我们目前也没有任何支持原生 FP8 精度训练的规划。目前原生 FP8 精度训练的主要挑战是梯度爆炸导致的精度溢出,为了保证训练的稳定性,需针对性地重新设计模型结构,然而目前还没有任何模型开发者愿意这么做。此外,使用原生 FP8 精度训练的模型,在推理时若没有 Hopper 架构 GPU,则只能以 BF16 精度进行计算,理论上其生成效果反而不如 FP8。因此,原生 FP8 精度训练技术是极不成熟的,我们静观开源社区的技术发展。 +即使硬件条件允许,我们目前也没有任何支持原生 FP8 精度训练的规划。 + +* 目前原生 FP8 精度训练的主要挑战是梯度爆炸导致的精度溢出,为了保证训练的稳定性,需针对性地重新设计模型结构,然而目前还没有任何模型开发者愿意这么做。 +* 此外,使用原生 FP8 精度训练的模型,在推理时若没有 Hopper 架构 GPU,则只能以 BF16 精度进行计算,理论上其生成效果反而不如 FP8。 + +因此,原生 FP8 精度训练技术是极不成熟的,我们静观开源社区的技术发展。 diff --git a/docs/README.md b/docs/README.md index d689601..9e71f31 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,7 +17,7 @@ 本节介绍 `DiffSynth-Studio` 所支持的 Diffusion 模型,部分模型 Pipeline 具备可控生成、并行加速等特色功能。 * [模型目录](./Model_Details/Overview.md) -* [Qwen-Image](./Model_Details/Qwen-Image.md)【TODO】 +* [Qwen-Image](./Model_Details/Qwen-Image.md) * [FLUX](./Model_Details/FLUX.md)【TODO】 * [Wan](./Model_Details/Wan.md)【TODO】 @@ -29,8 +29,8 @@ * [标准监督训练](./Training/Supervised_Fine_Tuning.md) * [在训练中启用 FP8 精度](./Training/FP8_Precision.md) * [端到端的蒸馏加速训练](./Training/Direct_Distill.md) -* 两阶段拆分训练 -* 差分 LoRA 训练 +* [两阶段拆分训练](./Training/Split_Training.md) +* [差分 LoRA 训练](./Training/Differential_LoRA.md) ## Section 4: 模型接入 @@ -64,4 +64,4 @@ 本节总结了开发者常见的问题,如果你在使用和开发中遇到了问题,请参考本节内容,如果仍无法解决,请到 GitHub 上给我们提 issue。 -* [常见问题](./QA.md)【TODO】 +* [常见问题](./QA.md) diff --git a/docs/Training/Differential_LoRA.md b/docs/Training/Differential_LoRA.md new file mode 100644 index 0000000..069b4da --- /dev/null +++ b/docs/Training/Differential_LoRA.md @@ -0,0 +1,38 @@ +# 差分 LoRA 训练 + +差分 LoRA 训练是一种特殊的 LoRA 训练方式,旨在让模型学习图像之间的差异。 + +## 训练方案 + +我们未能找到差分 LoRA 训练最早由谁提出,这一技术已经在开源社区中流传甚久。 + +假设我们有两张内容相似的图像:图 1 和图 2。例如两张图中分别有一辆车,但图 1 中画面细节更少,图 2 中画面细节更多。在差分 LoRA 训练中,我们进行两步训练: + +* 以图 1 为训练数据,以[标准监督训练](./Supervised_Fine_Tuning.md)的方式,训练 LoRA 1 +* 以图 2 为训练数据,将 LoRA 1 融入基础模型后,以[标准监督训练](./Supervised_Fine_Tuning.md)的方式,训练 LoRA 2 + +在第一步训练中,由于训练数据仅有一张图,LoRA 模型很容易过拟合,因此训练完成后,LoRA 1 会让模型毫不犹豫地生成图 1,无论随机种子是什么。在第二步训练中,LoRA 模型再次过拟合,因此训练完成后,在 LoRA 1 和 LoRA 2 的共同作用下,模型会毫不犹豫地生成图 2。简言之: + +* LoRA 1 = 生成图 1 +* LoRA 1 + LoRA 2 = 生成图 2 + +此时丢弃 LoRA 1,只使用 LoRA 2,模型将会理解图 1 和图 2 的差异,使生成的内容倾向于“更不像图1,更像图 2”。 + +单一训练数据可以保证模型能够过拟合到训练数据上,但稳定性不足。为了提高稳定性,我们可以用多个图像对(image pairs)进行训练,并将训练出的 LoRA 2 进行平均,得到效果更稳定的 LoRA。 + +用这一训练方案,可以训练出一些功能奇特的 LoRA 模型。例如,使用丑陋的和漂亮的图像对,训练提升图像美感的 LoRA;使用细节少的和细节丰富的图像对,训练增加图像细节的 LoRA。 + +## 模型效果 + +我们用差分 LoRA 训练技术训练了几个美学提升 LoRA,可前往对应的模型页面查看生成效果。 + +* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1) +* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) + +## 在训练框架中使用差分 LoRA 训练 + +第一步的训练与普通 LoRA 训练没有任何差异,在第二步的训练命令中,通过 `--preset_lora_path` 参数填入第一步的 LoRA 模型文件路径,并将 `--preset_lora_model` 设置为与 `lora_base_model` 相同的参数,即可将 LoRA 1 加载到基础模型中。 + +## 框架设计思路 + +在训练框架中,`--preset_lora_path` 指向的模型在 `DiffusionTrainingModule` 的 `switch_pipe_to_training_mode` 中完成加载。 diff --git a/docs/Training/Split_Training.md b/docs/Training/Split_Training.md index 0422063..d9f8374 100644 --- a/docs/Training/Split_Training.md +++ b/docs/Training/Split_Training.md @@ -1,2 +1,97 @@ # 两阶段拆分训练 +本文档介绍拆分训练,能够自动将训练过程拆分为两阶段进行,减少显存占用,同时加快训练速度。 + +(拆分训练是实验性特性,尚未进行大规模验证,如果在使用中出现问题,请在 GitHub 上给我们提 issue。) + +## 拆分训练 + +在大部分模型的训练过程中,大量计算发生在“前处理”中,即“与去噪模型无关的计算”,包括 VAE 编码、文本编码等。当对应的模型参数固定时,这部分计算的结果是重复的,在多个 epoch 中每个数据样本的计算结果完全相同,因此我们提供了“拆分训练”功能,该功能可以自动分析并拆分训练过程。 + +对于普通文生图模型的标准监督训练,拆分过程是非常简单的,只需要把所有 [`Pipeline Units`](/docs/Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段,将计算结果存储到硬盘中,然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传,情况就变得极其复杂,为此,我们引入了一个计算图拆分算法用于分析如何拆分计算。 + +## 计算图拆分算法 + +> (我们会在后续的文档更新中补充计算图拆分算法的详细细节) + +## 使用拆分训练 + +拆分训练已支持[标准监督训练](./Supervised_Fine_Tuning.md)和[直接蒸馏训练](./Direct_Distill.md),在训练命令中通过 `--task` 参数控制,以 Qwen-Image 模型的 LoRA 训练为例,拆分前的训练命令为: + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters +``` + +拆分后,在第一阶段中,做如下修改: + +* 将 `--dataset_repeat` 改为 1,避免重复计算 +* 将 `--output_path` 改为第一阶段计算结果保存的路径 +* 添加额外参数 `--task "sft:data_process"` +* 删除 `--model_id_with_origin_paths` 中的 DiT 模型 + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:data_process" +``` + +在第二阶段,做如下修改: + +* 将 `--dataset_base_path` 改为第一阶段的 `--output_path` +* 删除 `--dataset_metadata_path` +* 添加额外参数 `--task "sft:train"` +* 删除 `--model_id_with_origin_paths` 中的 Text Encoder 和 VAE 模型 + +```shell +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-LoRA-splited" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --task "sft:train" +``` + +我们提供了样例训练脚本和验证脚本,位于 `examples/qwen_image/model_training/special/split_training`。 + +## 训练框架设计思路 + +训练框架通过 `DiffusionTrainingModule` 的 `split_pipeline_units` 方法拆分 `Pipeline` 中的计算单元。 diff --git a/docs/Training/Supervised_Fine_Tuning.md b/docs/Training/Supervised_Fine_Tuning.md index 3ba9018..4b1d0f0 100644 --- a/docs/Training/Supervised_Fine_Tuning.md +++ b/docs/Training/Supervised_Fine_Tuning.md @@ -1,6 +1,6 @@ # 标准监督训练 -在理解 [Diffusion 模型基本原理](./Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。 +在理解 [Diffusion 模型基本原理](./Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理,帮助开发者编写新的训练代码,如需使用我们提供的默认训练功能,请参考[模型训练](/docs/Pipeline_Usage/Model_Training.md)。 回顾前文中的模型训练伪代码,当我们实际编写代码时,情况会变得极为复杂。部分模型需要输入额外的引导条件并进行预处理,例如 ControlNet;部分模型需要与去噪模型进行交叉式的计算,例如 VACE;部分模型因显存需求过大,需要开启 Gradient Checkpointing,例如 Qwen-Image 的 DiT。 diff --git a/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh b/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh index 2823b77..19191dd 100644 --- a/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh +++ b/examples/qwen_image/model_training/special/differential_training/Qwen-Image-LoRA.sh @@ -36,4 +36,5 @@ accelerate launch examples/qwen_image/model_training/train.py \ --use_gradient_checkpointing \ --dataset_num_workers 8 \ --find_unused_parameters \ - --preset_lora_path "./models/train/Qwen-Image-LoRA-deterministic/epoch-4.safetensors" + --preset_lora_path "./models/train/Qwen-Image-LoRA-deterministic/epoch-4.safetensors" \ + --preset_lora_model "dit"