Files
DiffSynth-Studio/docs/zh/Diffusion_Templates/Template_Model_Training.md
Artiprocher f58ba5a784 update docs
2026-04-16 20:24:22 +08:00

12 KiB
Raw Blame History

Template 模型训练

DiffSynth-Studio 目前已为 black-forest-labs/FLUX.2-klein-base-4B 提供了全面的 Templates 训练支持,更多模型的适配敬请期待。

基于预训练 Template 模型继续训练

如需基于我们预训练好的模型进行继续训练,请参考FLUX.2 中的表格,找到对应的训练脚本。

构建新的 Template 模型

Template 模型组件格式

一个 Template 模型与一个模型库(或一个本地文件夹)绑定,模型库中有代码文件 model.py 作为唯一入口。model.py 的模板如下:

import torch

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        yyy = xxx
        return {"yyy": yyy}

    def forward(self, yyy, **kwargs):
        zzz = yyy
        return {"zzz": zzz}

class DataProcessor:
    def __call__(self, www, **kwargs):
        xxx = www
        return {"xxx": xxx}

TEMPLATE_MODEL = CustomizedTemplateModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataProcessor

在 Template 模型推理时Template Input 先后经过 TEMPLATE_MODELprocess_inputsforward 得到 Template Cache。

flowchart LR;
    i@{shape: text, label: "Template Input"}-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};

在 Template 模型训练时Template Input 不再是用户的输入,而是从数据集中获取,由 TEMPLATE_DATA_PROCESSOR 进行计算得到。

flowchart LR;
    d@{shape: text, label: "Dataset"}-->dp[TEMPLATE_DATA_PROCESSOR]-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};

TEMPLATE_MODEL

TEMPLATE_MODEL 是 Template 模型的代码实现,需继承 torch.nn.Module,并编写 process_inputsforward 两个函数。process_inputsforward 构成完整的 Template 模型推理过程,我们将其拆分为两部分,是为了在训练中更容易适配两阶段拆分训练

  • process_inputs 需带有装饰器 @torch.no_grad(),进行不包含梯度的计算
  • forward 需包含训练模型所需的全部梯度计算过程,其输入与 process_inputs 的输出相同

process_inputsforward 需包含 **kwargs,保证兼容性,此外,我们提供了以下预留的参数

  • 如需在 process_inputsforward 中和基础模型 Pipeline 进行交互,例如调用基础模型 Pipeline 中的文本编码器进行计算,可在 process_inputsforward 的输入参数中增加字段 pipe
  • 如需在训练中启用 Gradient Checkpointing可在 forward 的输入参数中增加字段 use_gradient_checkpointinguse_gradient_checkpointing_offload
  • 多个 Template 模型需通过 model_id 区分 Template Inputs请不要在 process_inputsforward 的输入参数中使用这个字段

TEMPLATE_MODEL_PATH(可选项)

TEMPLATE_MODEL_PATH 是模型预训练权重文件的相对路径,例如

TEMPLATE_MODEL_PATH = "model.safetensors"

如需从多个模型文件中加载,可使用列表

TEMPLATE_MODEL_PATH = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]

如果需要随机初始化模型参数(模型还未训练),或不需要初始化模型参数,可将其设置为 None,或不设置

TEMPLATE_MODEL_PATH = None

TEMPLATE_DATA_PROCESSOR(可选项)

如需使用 DiffSynth-Studio 训练 Template 模型,则需构建训练数据集,数据集中的 metadata.json 包含 template_inputs 字段。metadata.json 中的 template_inputs 并不是直接输入给 Template 模型 process_inputs 的参数,而是提供给 TEMPLATE_DATA_PROCESSOR 的输入参数,由 TEMPLATE_DATA_PROCESSOR 计算出输入给 Template 模型 process_inputs 的参数。

例如,DiffSynth-Studio/F2KB4B-Template-Brightness 这一亮度控制模型的输入参数是 scale,即图像的亮度数值。scale 可以直接写在 metadata.json 中,此时 TEMPLATE_DATA_PROCESSOR 只需要传递参数:

[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"scale": 0.2}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"scale": 0.6}
    }
]
class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}

TEMPLATE_DATA_PROCESSOR = DataProcessor

也可在 metadata.json 中填写图像路径,直接在训练过程中计算 scale

[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    }
]
class DataProcessor:
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}

TEMPLATE_DATA_PROCESSOR = DataProcessor

训练 Template 模型

Template 模型“可训练”的充分条件是Template Cache 中的变量计算与基础模型 Pipeline 完全解耦,这些变量在推理过程中输入给基础模型 Pipeline 后,不会参与任何 Pipeline Unit 的计算,直达 model_fn

如果 Template 模型是“可训练”的,那么可以使用 DiffSynth-Studio 进行训练,以基础模型 black-forest-labs/FLUX.2-klein-base-4B 为例,在训练脚本中,填写字段:

  • --extra_inputs:额外输入,训练文生图模型的 Template 模型时只需填 template_inputs,训练图像编辑模型的 Template 模型时需填 edit_image,template_inputs
  • --template_model_id_or_pathTemplate 模型的魔搭模型 ID 或本地路径,框架会优先匹配本地路径,若本地路径不存在则从魔搭模型库中下载该模型,填写模型 ID 时,以“:”结尾,例如 "DiffSynth-Studio/Template-KleinBase4B-Brightness:"
  • --remove_prefix_in_ckpt:保存模型文件时,移除的 state dict 变量名前缀,填 "pipe.template_model." 即可
  • --trainable_models:可训练模型,填写 template_model 即可,若只需训练其中的某个组件,则需填写 template_model.xxx,template_model.yyy,以逗号分隔

以下是一个样例训练脚本,它会自动下载一个样例数据集,随机初始化模型权重后开始训练亮度控制模型:

modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "flux2/Template-KleinBase4B-Brightness/*" --local_dir ./data/diffsynth_example_dataset

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "examples/flux2/model_training/scripts/brightness" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_example" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters

与基础模型 Pipeline 组件交互

Diffusion Template 框架允许 Template 模型与基础模型 Pipeline 进行交互。例如,你可能需要使用基础模型 Pipeline 中的 text encoder 对文本进行编码,此时在 process_inputsforward 中使用预留字段 pipe 即可。

import torch

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.xxx = xxx()

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(text_emb)
        return {"text_emb": text_emb}

    def forward(self, text_emb, pipe, **kwargs):
        kv_cache = self.xxx(text_emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel

使用非训练的模型组件

在设计 Template 模型时,如果需要使用预训练的模型且不希望在训练过程中更新这部分参数,例如

import torch

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = XXXEncoder.from_pretrained(xxx)
        self.mlp = MLP()

    @torch.no_grad()
    def process_inputs(self, image, **kwargs):
        emb = self.image_encoder(image)
        return {"emb": emb}

    def forward(self, emb, **kwargs):
        kv_cache = self.mlp(emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel

此时需在训练命令中通过参数 --trainable_models template_model.mlp 设置为仅训练 mlp 部分。

上传 Template 模型

完成训练后,按照以下步骤可上传 Template 模型到魔搭社区

Step 1model.py 中填入训练好的模型文件名,例如

TEMPLATE_MODEL_PATH = "model.safetensors"

Step 2使用以下命令上传 model.py,其中 --token ms-xxxhttps://modelscope.cn/my/access/token 获取

modelscope upload user_name/your_model_id /path/to/your/model.py model.py --token ms-xxx

Step 3确认模型文件

确认要上传的模型文件,例如 epoch-1.safetensorsstep-2000.safetensors

注意DiffSynth-Studio 保存的模型文件中只包含可训练的参数,如果模型中包括非训练参数,则需要重新将非训练的模型参数打包才能进行推理,你可以通过以下代码进行打包:

from diffsynth.diffusion.template import load_template_model, load_state_dict
from safetensors.torch import save_file
import torch

model = load_template_model("path/to/your/template/model", torch_dtype=torch.bfloat16, device="cpu")
state_dict = load_state_dict("path/to/your/ckpt/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
state_dict.update(model.state_dict())
save_file(state_dict, "model.safetensors")

Step 4上传模型文件

modelscope upload user_name/your_model_id /path/to/your/model/epoch-1.safetensors model.safetensors --token ms-xxx

Step 5验证模型推理效果

from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Load base model
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
# Load Template model
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="user_name/your_model_id")
    ],
)
# Generate an image
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{xxx}],
)
image.save("image.png")