flux-kontext

This commit is contained in:
Artiprocher
2025-06-29 15:51:45 +08:00
parent 009f26bb40
commit 8c226e83a6
4 changed files with 442 additions and 2 deletions

View File

@@ -23,7 +23,9 @@ from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
from ..lora.flux_lora import FluxLoRALoader
from ..vram_management import gradient_checkpoint_forward
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
from ..models.flux_dit import RMSNorm
from ..vram_management import gradient_checkpoint_forward, enable_vram_management, AutoWrappedModule, AutoWrappedLinear
@@ -135,7 +137,119 @@ class FluxImagePipeline(BasePipeline):
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5):
pass
self.vram_management_enabled = True
if num_persistent_param_in_dit is not None:
vram_limit = None
else:
if vram_limit is None:
vram_limit = self.get_vram()
vram_limit = vram_limit - vram_buffer
if self.text_encoder_1 is not None:
dtype = next(iter(self.text_encoder_1.parameters())).dtype
enable_vram_management(
self.text_encoder_1,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
torch.nn.LayerNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.text_encoder_2 is not None:
dtype = next(iter(self.text_encoder_2.parameters())).dtype
enable_vram_management(
self.text_encoder_2,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Embedding: AutoWrappedModule,
T5LayerNorm: AutoWrappedModule,
T5DenseActDense: AutoWrappedModule,
T5DenseGatedActDense: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.dit is not None:
dtype = next(iter(self.dit.parameters())).dtype
device = "cpu" if vram_limit is not None else self.device
enable_vram_management(
self.dit,
module_map = {
RMSNorm: AutoWrappedModule,
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device=device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
max_num_param=num_persistent_param_in_dit,
overflow_module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae_decoder is not None:
dtype = next(iter(self.vae_decoder.parameters())).dtype
enable_vram_management(
self.vae_decoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
if self.vae_encoder is not None:
dtype = next(iter(self.vae_encoder.parameters())).dtype
enable_vram_management(
self.vae_encoder,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
torch.nn.Conv2d: AutoWrappedModule,
torch.nn.GroupNorm: AutoWrappedModule,
},
module_config = dict(
offload_dtype=dtype,
offload_device="cpu",
onload_dtype=dtype,
onload_device="cpu",
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=vram_limit,
)
@staticmethod

0
examples/flux/README.md Normal file
View File

326
examples/flux/README_zh.md Normal file
View File

@@ -0,0 +1,326 @@
# FLUX
[Switch to English](./README.md)
FLUX 是由 Black-Forest-Labs 开源的一系列图像生成模型。
**DiffSynth-Studio 启用了新的推理和训练框架,如需使用旧版本,请点击[这里](https://github.com/modelscope/DiffSynth-Studio/tree/3edf3583b1f08944cee837b94d9f84d669c2729c)。**
## 安装
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
## 快速开始
通过运行以下代码可以快速加载 FLUX.1-dev 模型并进行推理。
```python
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
image = pipe(prompt="a cat", seed=0)
image.save("image.jpg")
```
## 模型总览
**FLUX 系列模型的全新框架支持正在开发中,敬请期待!**
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](./model_inference/FLUX.1-dev.py)|[code](./model_training/full/FLUX.1-dev.sh)|[code](./model_training/validate_full/FLUX.1-dev.py)|[code](./model_training/lora/FLUX.1-dev.sh)|[code](./model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](./model_inference/FLUX.1-Kontext-dev.py)|||[code](./model_training/lora/FLUX.1-Kontext-dev.sh)|[code](./model_training/validate_lora/FLUX.1-Kontext-dev.py)|
## 模型推理
以下部分将会帮助您理解我们的功能并编写推理代码。
<details>
<summary>加载模型</summary>
模型通过 `from_pretrained` 加载:
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
],
)
```
其中 `torch_dtype``device` 是计算精度和计算设备。`model_configs` 可通过多种方式配置模型路径:
* 从[魔搭社区](https://modelscope.cn/)下载模型并加载。此时需要填写 `model_id``origin_file_pattern`,例如
```python
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
```
* 从本地文件路径加载模型。此时需要填写 `path`,例如
```python
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors")
```
对于从多个文件加载的单一模型,使用列表即可,例如
```python
ModelConfig(path=[
"models/xxx/diffusion_pytorch_model-00001-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00002-of-00003.safetensors",
"models/xxx/diffusion_pytorch_model-00003-of-00003.safetensors",
])
```
`from_pretrained` 还提供了额外的参数用于控制模型加载时的行为:
* `local_model_path`: 用于保存下载模型的路径,默认值为 `"./models"`
* `skip_download`: 是否跳过下载,默认值为 `False`。当您的网络无法访问[魔搭社区](https://modelscope.cn/)时,请手动下载必要的文件,并将其设置为 `True`
</details>
<details>
<summary>显存管理</summary>
DiffSynth-Studio 为 FLUX 模型提供了细粒度的显存管理,让模型能够在低显存设备上进行推理,可通过以下代码开启 offload 功能,在显存有限的设备上将部分模块 offload 到内存中。
```python
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/", offload_device="cpu"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", offload_device="cpu"),
],
)
pipe.enable_vram_management()
```
`enable_vram_management` 函数提供了以下参数,用于控制显存使用情况:
* `vram_limit`: 显存占用量GB默认占用设备上的剩余显存。注意这不是一个绝对限制当设置的显存不足以支持模型进行推理但实际可用显存足够时将会以最小化显存占用的形式进行推理。将其设置为0时将会实现理论最小显存占用。
* `vram_buffer`: 显存缓冲区大小GB默认为 0.5GB。由于部分较大的神经网络层在 onload 阶段会不可控地占用更多显存,因此一个显存缓冲区是必要的,理论上的最优值为模型中最大的层所占的显存。
* `num_persistent_param_in_dit`: DiT 模型中常驻显存的参数数量(个),默认为无限制。我们将会在未来删除这个参数,请不要依赖这个参数。
</details>
<details>
<summary>推理加速</summary>
* TeaCache加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
</details>
<details>
<summary>输入参数</summary>
Pipeline 在推理阶段能够接收以下输入参数:
* `prompt`: 提示词,描述画面中出现的内容。
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1当设置为大于1的数值时生效。
* `embedded_guidance`: FLUX-dev 的内嵌引导参数,默认值为 3.5。
* `t5_sequence_length`: T5 模型的文本向量序列长度,默认值为 512。
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
* `denoising_strength`: 去噪强度,范围是 01默认值为 1当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
* `height`: 图像高度,需保证高度为 16 的倍数。
* `width`: 图像宽度,需保证宽度为 16 的倍数。
* `seed`: 随机种子。默认为 `None`,即完全随机。
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
* `sigma_shift`: Rectified Flow 理论中的参数,默认为 3。数值越大模型在去噪的开始阶段停留的步骤数越多可适当调大这个参数来提高画面质量但会因生成过程与训练过程不一致导致生成的视频内容与训练数据存在差异。
* `num_inference_steps`: 推理次数,默认值为 30。
* `kontext_images`: Kontext 模型的输入图像。
* `controlnet_inputs`: ControlNet 模型的输入。
* `ipadapter_images`: IP-Adapter 模型的输入图像。
* `ipadapter_scale`: IP-Adapter 模型的控制强度。
</details>
## 模型训练
FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_training/train.py) 脚本进行。
<details>
<summary>脚本参数</summary>
脚本包含以下参数:
* 数据集
* `--dataset_base_path`: 数据集的根路径。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--height`: 图像或视频的高度。将 `height``width` 留空以启用动态分辨率。
* `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
* 训练
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch数量。
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理
* `use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `gradient_accumulation_steps`: 梯度累积步数。
* 其他
* `--align_to_opensource_format`: 是否将 FLUX DiT LoRA 的格式与开源版本对齐,仅对 FLUX.1-dev 和 FLUX.1-Kontext-dev 的 LoRA 训练生效。
此外,训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,在开始训练前运行 `accelerate config` 可配置 GPU 的相关参数。对于部分模型训练(例如模型的全量训练)脚本,我们提供了建议的 `accelerate` 配置文件,可在对应的训练脚本中查看。
</details>
<details>
<summary>Step 1: 准备数据集</summary>
数据集包含一系列文件,我们建议您这样组织数据集文件:
```
data/example_video_dataset/
├── metadata.csv
├── image1.jpg
└── image2.jpg
```
其中 `image1.jpg``image2.jpg` 为训练用视频数据,`metadata.csv` 为元数据列表,例如
```
video,prompt
image1.jpg,"a cat is sleeping"
image2.jpg,"a dog is running"
```
我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
```shell
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
```
数据集支持多种图片格式,`"jpg", "jpeg", "png", "webp"`
图片的尺寸可通过脚本参数 `--height``--width` 控制。当 `--height``--width` 为空时将会开启动态分辨率,按照数据集中每个视频或图片的实际宽高训练。
**我们强烈建议使用固定分辨率训练,因为在多卡训练中存在负载均衡问题。**
当模型需要额外输入时,例如具备控制能力的模型 [`black-forest-labs/FLUX.1-Kontext-dev`](https://modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) 所需的 `kontext_images`,请在数据集中补充相应的列,例如:
```
video,prompt,kontext_images
image1.jpg,"a cat is sleeping",image1_reference.jpg
```
额外输入若包含视频和图像文件,则需要在 `--data_file_keys` 参数中指定要解析的列名。可根据额外输入增加相应的列名,例如 `--data_file_keys "image,kontext_images"`
</details>
<details>
<summary>Step 2: 加载模型</summary>
类似于推理时的模型加载逻辑,可直接通过模型 ID 配置要加载的模型。例如,推理时我们通过以下设置加载模型
```python
model_configs=[
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
]
```
那么在训练时,填入以下参数即可加载对应的模型。
```shell
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors"
```
如果您希望从本地文件加载模型,例如推理时
```python
model_configs=[
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/text_encoder_2/"),
ModelConfig(path="models/black-forest-labs/FLUX.1-dev/ae.safetensors"),
]
```
那么训练时需设置为
```shell
--model_paths '[
"models/black-forest-labs/FLUX.1-dev/flux1-dev.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder/model.safetensors",
"models/black-forest-labs/FLUX.1-dev/text_encoder_2/",
"models/black-forest-labs/FLUX.1-dev/ae.safetensors"
]' \
```
</details>
<details>
<summary>Step 3: 设置可训练模块</summary>
训练框架支持训练基础模型,或 LoRA 模型。以下是几个例子:
* 全量训练 DiT 部分:`--trainable_models dit`
* 训练 DiT 部分的 LoRA 模型:`--lora_base_model dit --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" --lora_rank 32`
此外由于训练脚本中加载了多个模块text encoder、dit、vae保存模型文件时需要移除前缀例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`
</details>
<details>
<summary>Step 4: 启动训练程序</summary>
我们为每一个模型编写了训练命令,请参考本文档开头的表格。
</details>