# Template Model Training

DiffSynth-Studio currently provides comprehensive Template training support for [black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B), with more model adaptations coming soon.

## Continuing Training from Pretrained Models

To continue training from our pretrained models, refer to the table in [FLUX.2](../Model_Details/FLUX2.md#model-overview) to find the corresponding training script.

## Building New Template Models

### Template Model Component Format

A Template model binds to a model repository (or local folder) containing a code file `model.py` as the entry point. Here's the template for `model.py`:
```python
import torch


class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        yyy = xxx
        return {"yyy": yyy}

    def forward(self, yyy, **kwargs):
        zzz = yyy
        return {"zzz": zzz}


class DataProcessor:
    def __call__(self, www, **kwargs):
        xxx = www
        return {"xxx": xxx}


TEMPLATE_MODEL = CustomizedTemplateModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataProcessor
```
During Template model inference, Template Input passes through `TEMPLATE_MODEL`'s `process_inputs` and `forward` to generate Template Cache.

```mermaid
flowchart LR;
    i@{shape: text, label: "Template Input"}-->p[process_inputs];
    subgraph TEMPLATE_MODEL
    p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};
```

During Template model training, Template Input comes from the dataset through `TEMPLATE_DATA_PROCESSOR`.

```mermaid
flowchart LR;
    d@{shape: text, label: "Dataset"}-->dp[TEMPLATE_DATA_PROCESSOR]-->p[process_inputs];
    subgraph TEMPLATE_MODEL
    p[process_inputs]-->f[forward]
    end
    f[forward]-->c@{shape: text, label: "Template Cache"};
```
#### `TEMPLATE_MODEL`

`TEMPLATE_MODEL` implements the Template model logic, inheriting from `torch.nn.Module` with the required `process_inputs` and `forward` methods. These two methods form the complete Template model inference process, split into two stages to better support [two-stage split training](https://diffsynth-studio-doc.readthedocs.io/en/latest/Training/Split_Training.html).

* `process_inputs` must use `@torch.no_grad()` for gradient-free computation
* `forward` must contain all gradient computations required for training

Both methods should accept `**kwargs` for compatibility. Reserved parameters include:

* To interact with the base model Pipeline (e.g., to call the text encoder), add a `pipe` parameter to the method inputs
* To enable gradient checkpointing, add `use_gradient_checkpointing` and `use_gradient_checkpointing_offload` to the `forward` inputs
* When multiple Template models are used, `model_id` distinguishes their Template Inputs, so do not use this field as a method parameter
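The two-stage contract described above can be sketched without any framework: `process_inputs` runs once without gradients, its output dictionary is unpacked into `forward`, and `**kwargs` lets each method ignore reserved parameters (such as `pipe`) it does not need. The class below is a plain-Python illustration only, not the actual DiffSynth-Studio plumbing; in the real model these methods live on a `torch.nn.Module` and `process_inputs` is wrapped in `@torch.no_grad()`.

```python
# Framework-free sketch of the two-stage Template contract.
class SketchTemplateModel:
    def process_inputs(self, xxx, **kwargs):
        # Stage 1: gradient-free preprocessing (would be @torch.no_grad()).
        return {"yyy": xxx * 2}

    def forward(self, yyy, **kwargs):
        # Stage 2: all trainable computation happens here.
        return {"zzz": yyy + 1}


model = SketchTemplateModel()
# `pipe` is a reserved parameter; **kwargs absorbs it in methods that don't use it.
template_input = {"xxx": 3, "pipe": object()}
cache = model.forward(**model.process_inputs(**template_input))
print(cache)  # {'zzz': 7}
```

Note how the output dict of `process_inputs` becomes the keyword arguments of `forward`; this is what makes the two stages separable for split training.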
#### `TEMPLATE_MODEL_PATH` (Optional)

`TEMPLATE_MODEL_PATH` specifies the relative path to the pretrained weights. For example:

```python
TEMPLATE_MODEL_PATH = "model.safetensors"
```

For multi-file models:

```python
TEMPLATE_MODEL_PATH = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]
```

Set to `None` for random initialization:

```python
TEMPLATE_MODEL_PATH = None
```
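When `TEMPLATE_MODEL_PATH` is a list, the checkpoint is presumably loaded shard by shard and merged into a single state dict before being applied to the model. The toy sketch below illustrates that merge with plain dicts standing in for the `.safetensors` shards; it is not the actual loader code.

```python
# Illustrative only: merging multi-file checkpoint shards into one state dict.
# Each shard maps parameter names to tensors; here plain lists stand in.
shards = [
    {"layer1.weight": [1.0], "layer1.bias": [0.1]},
    {"layer2.weight": [2.0]},
    {"layer3.weight": [3.0]},
]

state_dict = {}
for shard in shards:
    state_dict.update(shard)  # each shard contributes its own parameters

print(sorted(state_dict))  # ['layer1.bias', 'layer1.weight', 'layer2.weight', 'layer3.weight']
```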
#### `TEMPLATE_DATA_PROCESSOR` (Optional)

To train Template models with DiffSynth-Studio, datasets should contain `template_inputs` fields in `metadata.json`. These fields pass through `TEMPLATE_DATA_PROCESSOR` to generate inputs for the Template model's methods.

For example, the brightness control model [DiffSynth-Studio/F2KB4B-Template-Brightness](https://modelscope.cn/models/DiffSynth-Studio/F2KB4B-Template-Brightness) takes `scale` as input:

```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"scale": 0.2}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"scale": 0.6}
    }
]
```
```python
class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}


TEMPLATE_DATA_PROCESSOR = DataProcessor
```
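Because `__call__` accepts `**kwargs`, any extra fields in a metadata entry are simply absorbed and dropped; only the declared parameters are picked out. A quick check of the pass-through processor above:

```python
class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}


# Extra metadata fields ("prompt", "image") are swallowed by **kwargs.
out = DataProcessor()(scale=0.2, prompt="a cat", image="images/image_1.jpg")
print(out)  # {'scale': 0.2}
```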
Or calculate `scale` from image paths:

```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    }
]
```
```python
from PIL import Image
import numpy as np


class DataProcessor:
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}


TEMPLATE_DATA_PROCESSOR = DataProcessor
```
### Training Template Models

A Template model is "trainable" if its Template Cache variables are fully decoupled from the base model Pipeline: these variables should reach `model_fn` without participating in any Pipeline Unit calculations.

For training with [black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B), use these training script parameters:

* `--extra_inputs`: Additional inputs. Use `template_inputs` for text-to-image models and `edit_image,template_inputs` for image editing models
* `--template_model_id_or_path`: Template model ID or local path (append a `:` suffix for ModelScope IDs, e.g., `"DiffSynth-Studio/Template-KleinBase4B-Brightness:"`)
* `--remove_prefix_in_ckpt`: State dict prefix to remove when saving models (use `"pipe.template_model."`)
* `--trainable_models`: Trainable components (use `"template_model"` for the full model, or `"template_model.xxx,template_model.yyy"` for specific components)

Example training script:
```shell
accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "examples/flux2/model_training/scripts/brightness" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_example" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
### Interacting with Base Model Pipeline Components

Template models can interact with base model Pipelines. For example, using the text encoder:

```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.xxx = xxx()

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(input_ids)
        return {"text_emb": text_emb}

    def forward(self, text_emb, pipe, **kwargs):
        kv_cache = self.xxx(text_emb)
        return {"kv_cache": kv_cache}


TEMPLATE_MODEL = CustomizedTemplateModel
```
### Using Non-Trainable Components

For models with pretrained components:

```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = XXXEncoder.from_pretrained(xxx)
        self.mlp = MLP()

    @torch.no_grad()
    def process_inputs(self, image, **kwargs):
        emb = self.image_encoder(image)
        return {"emb": emb}

    def forward(self, emb, **kwargs):
        kv_cache = self.mlp(emb)
        return {"kv_cache": kv_cache}


TEMPLATE_MODEL = CustomizedTemplateModel
```

Set `--trainable_models template_model.mlp` to train only the MLP component.
### Uploading Template Models

After training, follow these steps to upload to ModelScope:

1. Set the model path in `model.py`:

   ```python
   TEMPLATE_MODEL_PATH = "model.safetensors"
   ```

2. Upload using the ModelScope CLI:

   ```shell
   modelscope upload user_name/your_model_id /path/to/your/model.py model.py --token ms-xxx
   ```
3. Package model files:

   ```python
   from diffsynth.diffusion.template import load_template_model, load_state_dict
   from safetensors.torch import save_file
   import torch

   model = load_template_model("path/to/your/template/model", torch_dtype=torch.bfloat16, device="cpu")
   state_dict = load_state_dict("path/to/your/ckpt/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
   state_dict.update(model.state_dict())
   save_file(state_dict, "model.safetensors")
   ```

4. Upload the model file:

   ```shell
   modelscope upload user_name/your_model_id /path/to/your/model/epoch-1.safetensors model.safetensors --token ms-xxx
   ```
5. Verify inference:

   ```python
   from diffsynth.diffusion.template import TemplatePipeline
   from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
   import torch

   # Load base model
   pipe = Flux2ImagePipeline.from_pretrained(
       torch_dtype=torch.bfloat16,
       device="cuda",
       model_configs=[
           ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
           ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
           ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
       ],
       tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
   )

   # Load Template model
   template_pipeline = TemplatePipeline.from_pretrained(
       torch_dtype=torch.bfloat16,
       device="cuda",
       model_configs=[
           ModelConfig(model_id="user_name/your_model_id"),
       ],
   )

   # Generate image
   image = template_pipeline(
       pipe,
       prompt="a cat",
       seed=0, cfg_scale=4,
       height=1024, width=1024,
       template_inputs=[{xxx}],
   )
   image.save("image.png")
   ```