Template Model Training
DiffSynth-Studio currently provides comprehensive Template training support for black-forest-labs/FLUX.2-klein-base-4B, with more model adaptations coming soon.
Continuing Training from Pretrained Models
To continue training from our pretrained models, refer to the table in FLUX.2 to find the corresponding training script.
Building New Template Models
Template Model Component Format
A Template model binds to a model repository (or local folder) containing a code file model.py as the entry point. Here's the template for model.py:
```python
import torch

class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        yyy = xxx
        return {"yyy": yyy}

    def forward(self, yyy, **kwargs):
        zzz = yyy
        return {"zzz": zzz}

class DataProcessor:
    def __call__(self, www, **kwargs):
        xxx = www
        return {"xxx": xxx}

TEMPLATE_MODEL = CustomizedTemplateModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataProcessor
```
During Template model inference, Template Input passes through TEMPLATE_MODEL's process_inputs and forward to generate Template Cache.
```mermaid
flowchart LR;
    i@{shape: text, label: "Template Input"}-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};
```
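The flow above can be sketched in plain Python. The toy model and driver lines below are illustrative assumptions, not DiffSynth-Studio internals; the key point is that the dict returned by `process_inputs` is merged into the kwargs before `forward` runs.

```python
import torch

class ToyTemplateModel(torch.nn.Module):
    """Hypothetical minimal model following the model.py contract."""
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        # Gradient-free preprocessing stage.
        return {"yyy": xxx * 2}

    def forward(self, yyy, **kwargs):
        # Trainable stage; extra keys (e.g. "xxx") are absorbed by **kwargs.
        return {"zzz": yyy * self.scale}

# Sketch of how a pipeline might chain the two stages:
model = ToyTemplateModel()
inputs = {"xxx": torch.tensor([1.0, 2.0])}
inputs.update(model.process_inputs(**inputs))  # adds "yyy"
cache = model(**inputs)                        # Template Cache: {"zzz": ...}
```

Because both methods accept `**kwargs`, the merged dict can be passed wholesale to each stage.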
During Template model training, Template Input comes from the dataset through TEMPLATE_DATA_PROCESSOR.
```mermaid
flowchart LR;
    d@{shape: text, label: "Dataset"}-->dp[TEMPLATE_DATA_PROCESSOR]-->p[process_inputs];
    subgraph TEMPLATE_MODEL
        p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};
```
TEMPLATE_MODEL
TEMPLATE_MODEL implements the Template model logic, inheriting from torch.nn.Module with required process_inputs and forward methods. These two methods form the complete Template model inference process, split into two stages to better support two-stage split training.
- `process_inputs` must use `@torch.no_grad()` for gradient-free computation
- `forward` must contain all gradient computations required for training
Both methods should accept **kwargs for compatibility. Reserved parameters include:
- To interact with the base model Pipeline (e.g., to call the text encoder), add a `pipe` parameter to the method inputs
- To enable gradient checkpointing, add `use_gradient_checkpointing` and `use_gradient_checkpointing_offload` to the `forward` inputs
- Multiple Template models use `model_id` to distinguish Template Inputs; do not use this field in method parameters
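As an illustrative sketch (not the actual FLUX.2 template code), a `forward` that honors the `use_gradient_checkpointing` flag can wrap its trainable computation in `torch.utils.checkpoint`; the `use_gradient_checkpointing_offload` variant is omitted here because its wiring is framework-specific:

```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedTemplate(torch.nn.Module):
    """Hypothetical template whose forward supports the reserved kwargs."""
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8)
        )

    def forward(self, emb, use_gradient_checkpointing=False, **kwargs):
        if use_gradient_checkpointing:
            # Recompute activations during backward to save memory.
            kv = checkpoint(self.mlp, emb, use_reentrant=False)
        else:
            kv = self.mlp(emb)
        return {"kv_cache": kv}

m = CheckpointedTemplate()
x = torch.randn(2, 8)
out_plain = m(x)["kv_cache"]
out_ckpt = m(x, use_gradient_checkpointing=True)["kv_cache"]
```

Both paths produce identical outputs; checkpointing only changes what is stored for the backward pass.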
TEMPLATE_MODEL_PATH (Optional)
TEMPLATE_MODEL_PATH specifies the relative path to pretrained weights. For example:
```python
TEMPLATE_MODEL_PATH = "model.safetensors"
```
For multi-file models:
```python
TEMPLATE_MODEL_PATH = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]
```
Set to None for random initialization:
```python
TEMPLATE_MODEL_PATH = None
```
TEMPLATE_DATA_PROCESSOR (Optional)
To train Template models with DiffSynth-Studio, datasets should contain template_inputs fields in metadata.json. These fields pass through TEMPLATE_DATA_PROCESSOR to generate inputs for Template model methods.
For example, the brightness control model DiffSynth-Studio/F2KB4B-Template-Brightness takes scale as input:
```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"scale": 0.2}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"scale": 0.6}
    }
]
```
```python
class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}

TEMPLATE_DATA_PROCESSOR = DataProcessor
```
Or calculate scale from image paths:
```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    }
]
```
```python
from PIL import Image
import numpy as np

class DataProcessor:
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}

TEMPLATE_DATA_PROCESSOR = DataProcessor
```
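To sanity-check the brightness computation without reading files, a variant of the processor can accept an in-memory array (the array input is a hypothetical substitute for the image path used above):

```python
import numpy as np

class ArrayDataProcessor:
    """Variant of the brightness processor taking an in-memory array
    (hypothetical; the original loads the image from its path)."""
    def __call__(self, image, **kwargs):
        image = np.asarray(image)
        # Mean pixel value normalized to [0, 1] becomes the brightness scale.
        return {"scale": image.astype(np.float32).mean() / 255}

processor = ArrayDataProcessor()
white = np.full((4, 4, 3), 255, dtype=np.uint8)
scale = processor(image=white)["scale"]  # a pure-white image yields 1.0
```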
Training Template Models
A Template model is "trainable" only if its Template Cache variables are fully decoupled from the base model Pipeline: these variables must reach `model_fn` without participating in any Pipeline Unit computation.
For training with black-forest-labs/FLUX.2-klein-base-4B, use these training script parameters:
- `--extra_inputs`: Additional inputs. Use `template_inputs` for text-to-image models, and `edit_image,template_inputs` for image editing models
- `--template_model_id_or_path`: Template model ID or local path (use a `:` suffix for ModelScope IDs, e.g., `"DiffSynth-Studio/Template-KleinBase4B-Brightness:"`)
- `--remove_prefix_in_ckpt`: State dict prefix to remove when saving models (use `"pipe.template_model."`)
- `--trainable_models`: Trainable components (use `"template_model"` for the full model, or `"template_model.xxx,template_model.yyy"` for specific components)
Example training script:
```shell
accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "examples/flux2/model_training/scripts/brightness" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_example" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
Interacting with Base Model Pipeline Components
Template models can interact with base model Pipelines. For example, using the text encoder:
```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.xxx = xxx()

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(input_ids)
        return {"text_emb": text_emb}

    def forward(self, text_emb, pipe, **kwargs):
        kv_cache = self.xxx(text_emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel
```
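Such a template can be exercised without a real pipeline by passing a stand-in object that exposes only the attributes `process_inputs` touches. The `SimpleNamespace` stub and the toy tokenizer/encoder below are illustrative assumptions, not the DiffSynth-Studio API:

```python
import torch
from types import SimpleNamespace

class TextTemplateModel(torch.nn.Module):
    """Hypothetical template that consumes pipeline components via `pipe`."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(input_ids)
        return {"text_emb": text_emb}

    def forward(self, text_emb, **kwargs):
        return {"kv_cache": self.proj(text_emb)}

# Stand-in pipeline exposing only the pieces process_inputs needs.
fake_pipe = SimpleNamespace(
    tokenizer=lambda text: torch.arange(len(text)).unsqueeze(0),
    text_encoder=lambda ids: torch.zeros(1, ids.shape[1], 16),
)
model = TextTemplateModel()
state = model.process_inputs(text="a cat", pipe=fake_pipe)
cache = model(**state)
```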
Using Non-Trainable Components
For models with pretrained components:
```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = XXXEncoder.from_pretrained(xxx)
        self.mlp = MLP()

    @torch.no_grad()
    def process_inputs(self, image, **kwargs):
        emb = self.image_encoder(image)
        return {"emb": emb}

    def forward(self, emb, **kwargs):
        kv_cache = self.mlp(emb)
        return {"kv_cache": kv_cache}

TEMPLATE_MODEL = CustomizedTemplateModel
```
Set `--trainable_models template_model.mlp` to train only the MLP component.
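The effect of `--trainable_models template_model.mlp` can be emulated directly in PyTorch: freeze the whole module, then re-enable the named submodule. This is a sketch of the resulting parameter selection, not the trainer's actual implementation:

```python
import torch

class TemplateWithFrozenEncoder(torch.nn.Module):
    """Stand-in: a pretrained encoder plus a small trainable head."""
    def __init__(self):
        super().__init__()
        self.image_encoder = torch.nn.Linear(32, 16)  # stands in for a pretrained encoder
        self.mlp = torch.nn.Linear(16, 16)

model = TemplateWithFrozenEncoder()
# Emulate --trainable_models "template_model.mlp":
model.requires_grad_(False)
model.mlp.requires_grad_(True)

trainable = sorted(n for n, p in model.named_parameters() if p.requires_grad)
```

Only the `mlp.*` parameters end up with `requires_grad=True`, so the optimizer never touches the encoder.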
Uploading Template Models
After training, follow these steps to upload to ModelScope:
- Set the model path in `model.py`:

```python
TEMPLATE_MODEL_PATH = "model.safetensors"
```
- Upload `model.py` using the ModelScope CLI:

```shell
modelscope upload user_name/your_model_id /path/to/your/model.py model.py --token ms-xxx
```
- Package the model files:

```python
from diffsynth.diffusion.template import load_template_model, load_state_dict
from safetensors.torch import save_file
import torch

model = load_template_model("path/to/your/template/model", torch_dtype=torch.bfloat16, device="cpu")
state_dict = load_state_dict("path/to/your/ckpt/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
state_dict.update(model.state_dict())
save_file(state_dict, "model.safetensors")
```
- Upload the model file:

```shell
modelscope upload user_name/your_model_id /path/to/your/model/epoch-1.safetensors model.safetensors --token ms-xxx
```
- Verify inference:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Load base model
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)

# Load Template model
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="user_name/your_model_id"),
    ],
)

# Generate image
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{xxx}],
)
image.save("image.png")
```