DiffSynth-Studio/docs/en/Diffusion_Templates/Template_Model_Training.md
Artiprocher f58ba5a784 update docs
2026-04-16 20:24:22 +08:00

Template Model Training

DiffSynth-Studio currently provides comprehensive Template training support for black-forest-labs/FLUX.2-klein-base-4B, with more model adaptations coming soon.

Continuing Training from Pretrained Models

To continue training from our pretrained models, refer to the table in FLUX.2 to find the corresponding training script.

Building New Template Models

Template Model Component Format

A Template model binds to a model repository (or local folder) containing a code file model.py as the entry point. Here's the template for model.py:

import torch

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        # Stage 1: gradient-free preprocessing
        yyy = xxx
        return {"yyy": yyy}

    def forward(self, yyy, **kwargs):
        # Stage 2: trainable computation; the outputs form the Template Cache
        zzz = yyy
        return {"zzz": zzz}

class DataProcessor:
    def __call__(self, www, **kwargs):
        # Maps dataset `template_inputs` fields to `process_inputs` arguments
        xxx = www
        return {"xxx": xxx}

TEMPLATE_MODEL = CustomizedTemplateModel       # model class (required)
TEMPLATE_MODEL_PATH = "model.safetensors"      # pretrained weights (optional)
TEMPLATE_DATA_PROCESSOR = DataProcessor        # dataset adapter (optional)

During Template model inference, the Template Input passes through TEMPLATE_MODEL's process_inputs and forward to produce the Template Cache.

flowchart LR;
    i@{shape: text, label: "Template Input"}-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};
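In code, the two stages chain together: the keyword dictionary returned by process_inputs is unpacked into forward, whose return value is the Template Cache. A minimal sketch of this calling convention (plain Python, placeholder names, torch omitted for brevity):

```python
# Hypothetical sketch of the two-stage inference convention.
class SketchTemplateModel:
    def process_inputs(self, xxx, **kwargs):
        # Stage 1: gradient-free preprocessing
        return {"yyy": xxx}

    def forward(self, yyy, **kwargs):
        # Stage 2: produces the Template Cache
        return {"zzz": yyy * 2}

model = SketchTemplateModel()
template_input = {"xxx": 3}
stage1 = model.process_inputs(**template_input)
template_cache = model.forward(**stage1)
print(template_cache)  # {'zzz': 6}
```

Because each stage both consumes and returns keyword dictionaries, the framework can run stage 1 once, cache its outputs, and replay only stage 2 during training.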

During Template model training, the Template Input comes from the dataset and is produced by TEMPLATE_DATA_PROCESSOR.

flowchart LR;
    d@{shape: text, label: "Dataset"}-->dp[TEMPLATE_DATA_PROCESSOR]-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};

TEMPLATE_MODEL

TEMPLATE_MODEL implements the Template model logic, inheriting from torch.nn.Module with required process_inputs and forward methods. These two methods form the complete Template model inference process, split into two stages to better support two-stage split training.

  • process_inputs must use @torch.no_grad() for gradient-free computation
  • forward must contain all gradient computations required for training

Both methods should accept **kwargs for compatibility. Reserved parameters include:

  • To interact with the base model Pipeline (e.g., call text encoder), add pipe parameter to method inputs
  • To enable Gradient Checkpointing, add use_gradient_checkpointing and use_gradient_checkpointing_offload to forward inputs
  • When multiple Template models are loaded, the model_id field distinguishes their Template Inputs; do not use model_id as a method parameter name
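As an illustration of where the reserved parameters go, a Template model using both pipe and the gradient-checkpointing flags might have signatures like the following (hypothetical sketch; the @torch.no_grad() decorator is omitted here since this is plain Python):

```python
# Hypothetical signatures showing the reserved keyword arguments.
class SketchTemplateModel:
    @staticmethod
    def process_inputs(text, pipe=None, **kwargs):
        # `pipe` exposes base-pipeline components, e.g. pipe.text_encoder
        return {"emb": text}

    @staticmethod
    def forward(emb, use_gradient_checkpointing=False,
                use_gradient_checkpointing_offload=False, **kwargs):
        # The trainer forwards the gradient-checkpointing flags here
        return {"kv_cache": emb}

out = SketchTemplateModel.forward(emb="e", use_gradient_checkpointing=True)
print(out)  # {'kv_cache': 'e'}
```

The trailing **kwargs in both methods lets the framework pass fields a given model does not use without raising a TypeError.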

TEMPLATE_MODEL_PATH (Optional)

TEMPLATE_MODEL_PATH specifies the relative path to pretrained weights. For example:

TEMPLATE_MODEL_PATH = "model.safetensors"

For multi-file models:

TEMPLATE_MODEL_PATH = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]

Set to None for random initialization:

TEMPLATE_MODEL_PATH = None

TEMPLATE_DATA_PROCESSOR (Optional)

To train Template models with DiffSynth-Studio, each sample in the dataset's metadata.json should contain a template_inputs field. These fields pass through TEMPLATE_DATA_PROCESSOR to generate the inputs for the Template model's methods.

For example, the brightness control model DiffSynth-Studio/F2KB4B-Template-Brightness takes scale as input:

[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"scale": 0.2}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"scale": 0.6}
    }
]

class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}

TEMPLATE_DATA_PROCESSOR = DataProcessor

Or calculate scale from image paths:

[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    }
]

from PIL import Image
import numpy as np

class DataProcessor:
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}

TEMPLATE_DATA_PROCESSOR = DataProcessor
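A processor like this can be checked without a real image file by feeding it a synthetic pixel array; a runnable sketch (numpy only, no file I/O, the array stands in for a decoded image):

```python
import numpy as np

# Sketch of the brightness DataProcessor, computing `scale` from pixel data.
# A synthetic uint8 array replaces the decoded image file.
class DataProcessor:
    def __call__(self, image, **kwargs):
        image = np.asarray(image)
        return {"scale": image.astype(np.float32).mean() / 255}

processor = DataProcessor()
gray = np.full((4, 4, 3), 128, dtype=np.uint8)  # uniform mid-gray image
result = processor(image=gray)
print(result)  # scale close to 0.5 for mid-gray input
```

The returned scale lies in [0, 1]: 0 for an all-black image, 1 for all-white.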

Training Template Models

A Template model is "trainable" only if its Template Cache variables are fully decoupled from the base model Pipeline: these variables must reach model_fn without participating in any Pipeline Unit calculation.

For training with black-forest-labs/FLUX.2-klein-base-4B, use these training script parameters:

  • --extra_inputs: Additional inputs. Use template_inputs for text-to-image models, or edit_image,template_inputs for image-editing models
  • --template_model_id_or_path: Template model ID or local path (use : suffix for ModelScope IDs, e.g., "DiffSynth-Studio/Template-KleinBase4B-Brightness:")
  • --remove_prefix_in_ckpt: State dict prefix to remove when saving models (use "pipe.template_model.")
  • --trainable_models: Trainable components (use "template_model" for full model, or "template_model.xxx,template_model.yyy" for specific components)
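The --remove_prefix_in_ckpt option simply strips a key prefix when saving the state dict, so the saved checkpoint can later be loaded directly into the Template model. A sketch of the idea (the keys shown are hypothetical):

```python
# Sketch: stripping the "pipe.template_model." prefix from checkpoint keys,
# mirroring what --remove_prefix_in_ckpt does at save time.
def remove_prefix(state_dict, prefix="pipe.template_model."):
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

ckpt = {"pipe.template_model.mlp.weight": 1, "pipe.template_model.mlp.bias": 2}
print(remove_prefix(ckpt))  # {'mlp.weight': 1, 'mlp.bias': 2}
```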

Example training script:

accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "examples/flux2/model_training/scripts/brightness" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_example" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters

Interacting with Base Model Pipeline Components

Template models can interact with base model Pipelines. For example, using the text encoder:

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.xxx = xxx()

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(input_ids)
        return {"text_emb": text_emb}

    def forward(self, text_emb, pipe, **kwargs):
        kv_cache = self.xxx(text_emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel

Using Non-Trainable Components

For models with pretrained components:

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = XXXEncoder.from_pretrained(xxx)
        self.mlp = MLP()

    @torch.no_grad()
    def process_inputs(self, image, **kwargs):
        emb = self.image_encoder(image)
        return {"emb": emb}

    def forward(self, emb, **kwargs):
        kv_cache = self.mlp(emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel

Set --trainable_models template_model.mlp to train only the MLP component.
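Conceptually, --trainable_models selects parameters by name prefix inside the Template model; the actual trainer logic may differ, but the selection can be sketched over hypothetical parameter names like so:

```python
# Sketch of how "template_model.mlp" narrows training to one submodule:
# only parameters whose names fall under the given prefix are kept trainable.
def select_trainable(param_names, spec="template_model.mlp"):
    prefix = spec.split("template_model.", 1)[1] + "."
    return [n for n in param_names if n.startswith(prefix)]

names = ["image_encoder.layer.weight", "mlp.fc1.weight", "mlp.fc1.bias"]
print(select_trainable(names))  # only the mlp.* parameters remain
```

With this selection, the pretrained image_encoder stays frozen while only the MLP receives gradient updates.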

Uploading Template Models

After training, follow these steps to upload to ModelScope:

  1. Set model path in model.py:
TEMPLATE_MODEL_PATH = "model.safetensors"
  2. Upload using the ModelScope CLI:
modelscope upload user_name/your_model_id /path/to/your/model.py model.py --token ms-xxx
  3. Package model files:
from diffsynth.diffusion.template import load_template_model, load_state_dict
from safetensors.torch import save_file
import torch

model = load_template_model("path/to/your/template/model", torch_dtype=torch.bfloat16, device="cpu")
state_dict = load_state_dict("path/to/your/ckpt/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
state_dict.update(model.state_dict())
save_file(state_dict, "model.safetensors")
  4. Upload the model file:
modelscope upload user_name/your_model_id /path/to/your/model/epoch-1.safetensors model.safetensors --token ms-xxx
  5. Verify inference:
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Load base model
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)

# Load Template model
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="user_name/your_model_id")
    ],
)

# Generate image
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{xxx}],
)
image.save("image.png")