mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-16 23:38:19 +00:00

update docs

docs/en/Diffusion_Templates/Template_Model_Inference.md
# Template Model Inference

## Enabling Template Models on Base Model Pipelines

Using the base model [black-forest-labs/FLUX.2-klein-base-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B) as an example, generating an image with only the base model looks like this:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Load the base model
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)

# Generate an image
image = pipe(
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
)
image.save("image.png")
```
The Template model [DiffSynth-Studio/F2KB4B-Template-Brightness](https://modelscope.cn/models/DiffSynth-Studio/F2KB4B-Template-Brightness) can control image brightness during generation. Through `TemplatePipeline`, it can be loaded from ModelScope (via `ModelConfig(model_id="xxx/xxx")`) or from a local path (via `ModelConfig(path="xxx")`). Passing `scale=0.8` increases image brightness. Note that the input parameters previously passed to `pipe` must now be passed to `template_pipeline`, with `template_inputs` added.

```python
# Load the Template model
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Template-Brightness")
    ],
)

# Generate an image
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{"scale": 0.8}],
)
image.save("image_0.8.png")
```
## CFG Enhancement for Template Models

Template models can enable CFG (Classifier-Free Guidance) to make their control effects more pronounced. For example, with [DiffSynth-Studio/F2KB4B-Template-Brightness](https://modelscope.cn/models/DiffSynth-Studio/F2KB4B-Template-Brightness), add `negative_template_inputs` to the `TemplatePipeline` call and set its `scale` to 0.5; contrasting the positive and negative sides produces a more noticeable brightness change.

```python
# Generate an image with CFG
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{"scale": 0.8}],
    negative_template_inputs=[{"scale": 0.5}],
)
image.save("image_0.8_cfg.png")
```
## Low VRAM Support

Template models do not yet support the main framework's VRAM management, but lazy loading can be used instead: each Template model is loaded only when it is needed for inference. This significantly reduces VRAM requirements when multiple Template models are enabled, since peak VRAM usage is that of a single Template model. Add the parameter `lazy_loading=True` to enable it.

```python
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="DiffSynth-Studio/F2KB4B-Template-Brightness")
    ],
    lazy_loading=True,
)
```
The base model's Pipeline and the Template Pipeline are completely independent, and VRAM management can be enabled for each on demand.

When a Template model's output contains LoRA in its Template Cache, you need to either enable VRAM management for the base model's Pipeline or enable LoRA hot loading (with the code below); otherwise the LoRA weights will stack up across calls.

```python
pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)
```
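The pitfall can be illustrated with plain numbers: merging a LoRA delta into base weights is additive, so applying it again without unmerging first accumulates the delta. This is a minimal stdlib-only sketch (not DiffSynth code) of why hot loading, which removes the previously merged delta before adding the new one, keeps weights correct.

```python
# Minimal illustration (not DiffSynth code): LoRA merging is additive,
# W' = W + B @ A, so re-merging without unmerging accumulates deltas.

def merge_lora(weight: float, delta: float) -> float:
    # Scalar stand-in for W' = W + B @ A
    return weight + delta

def hot_load(weight: float, old_delta: float, new_delta: float) -> float:
    # Hot loading: unmerge the previously applied delta, then merge the new one.
    return weight - old_delta + new_delta

base = 1.0
lora = 0.25

w = merge_lora(base, lora)        # first template call: 1.25, correct
w_stacked = merge_lora(w, lora)   # second call without unloading: 1.5, wrong
w_hot = hot_load(w, lora, lora)   # hot loading: back to 1.25, correct
print(w, w_stacked, w_hot)        # 1.25 1.5 1.25
```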
## Enabling Multiple Template Models

`TemplatePipeline` can load multiple Template models. During inference, use `model_id` in `template_inputs` to indicate which Template model each input targets.

With VRAM management enabled for the base model's Pipeline and lazy loading enabled for the Template Pipeline, you can load any number of Template models.
```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch
from PIL import Image

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cuda",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)
pipe.dit = pipe.enable_lora_hot_loading(pipe.dit)
template = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    lazy_loading=True,
    model_configs=[
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Brightness"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-ControlNet"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Edit"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Upscaler"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-SoftRGB"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Sharpness"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Inpaint"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-Aesthetic"),
        ModelConfig(model_id="DiffSynth-Studio/Template-KleinBase4B-PandaMeme"),
    ],
)
```
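In the examples that follow, `model_id` appears to be the index of the corresponding entry in the `model_configs` list (0 = Brightness, 1 = ControlNet, ..., in load order). Here is a hypothetical stdlib-only sketch of that routing, not the actual `TemplatePipeline` implementation:

```python
# Hypothetical sketch of model_id routing; the real TemplatePipeline
# implementation may differ. model_id indexes the model_configs list.

models = ["Brightness", "ControlNet", "Edit", "Upscaler", "SoftRGB",
          "Sharpness", "Inpaint", "Aesthetic", "PandaMeme"]

def route_inputs(template_inputs):
    # Map each input dict to the template model it targets.
    routed = {}
    for inputs in template_inputs:
        model = models[inputs.get("model_id", 0)]
        routed[model] = {k: v for k, v in inputs.items() if k != "model_id"}
    return routed

print(route_inputs([
    {"model_id": 3, "image": "lowres.jpg"},
    {"model_id": 5, "scale": 1},
]))
# {'Upscaler': {'image': 'lowres.jpg'}, 'Sharpness': {'scale': 1}}
```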
### Super-Resolution + Sharpness Enhancement

Combining [DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler) and [DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness) can upscale blurry images while improving detail clarity.

```python
image = template(
    pipe,
    prompt="A cat is sitting on a stone.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[
        {
            "model_id": 3,
            "image": Image.open("data/assets/image_lowres_100.jpg"),
            "prompt": "A cat is sitting on a stone.",
        },
        {
            "model_id": 5,
            "scale": 1,
        },
    ],
    negative_template_inputs=[
        {
            "model_id": 3,
            "image": Image.open("data/assets/image_lowres_100.jpg"),
            "prompt": "",
        },
        {
            "model_id": 5,
            "scale": 0,
        },
    ],
)
image.save("image_Upscaler_Sharpness.png")
```

| Low Resolution Input | High Resolution Output |
|----------------------|------------------------|
|  |  |
### Structure Control + Aesthetic Alignment + Sharpness Enhancement

[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet) controls composition, [DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic) fills in details, and [DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness) ensures clarity. Combining these three Template models produces exquisite images.

```python
image = template(
    pipe,
    prompt="A cat is sitting on a stone, bathed in bright sunshine.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[
        {
            "model_id": 1,
            "image": Image.open("data/assets/image_depth.jpg"),
            "prompt": "A cat is sitting on a stone, bathed in bright sunshine.",
        },
        {
            "model_id": 7,
            "lora_ids": list(range(1, 180, 2)),
            "lora_scales": 2.0,
            "merge_type": "mean",
        },
        {
            "model_id": 5,
            "scale": 0.8,
        },
    ],
    negative_template_inputs=[
        {
            "model_id": 1,
            "image": Image.open("data/assets/image_depth.jpg"),
            "prompt": "",
        },
        {
            "model_id": 7,
            "lora_ids": list(range(1, 180, 2)),
            "lora_scales": 2.0,
            "merge_type": "mean",
        },
        {
            "model_id": 5,
            "scale": 0,
        },
    ],
)
image.save("image_Controlnet_Aesthetic_Sharpness.png")
```

| Structure Control Image | Output Image |
|-------------------------|--------------|
|  |  |
### Structure Control + Image Editing + Color Adjustment

[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet) controls composition, [DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit) preserves original image details such as fur texture, and [DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB) controls color tones, creating an artistic result.

```python
image = template(
    pipe,
    prompt="A cat is sitting on a stone. Colored ink painting.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[
        {
            "model_id": 1,
            "image": Image.open("data/assets/image_depth.jpg"),
            "prompt": "A cat is sitting on a stone. Colored ink painting.",
        },
        {
            "model_id": 2,
            "image": Image.open("data/assets/image_reference.jpg"),
            "prompt": "Convert the image style to colored ink painting.",
        },
        {
            "model_id": 4,
            "R": 0.9,
            "G": 0.5,
            "B": 0.3,
        },
    ],
    negative_template_inputs=[
        {
            "model_id": 1,
            "image": Image.open("data/assets/image_depth.jpg"),
            "prompt": "",
        },
        {
            "model_id": 2,
            "image": Image.open("data/assets/image_reference.jpg"),
            "prompt": "",
        },
    ],
)
image.save("image_Controlnet_Edit_SoftRGB.png")
```

| Structure Control Image | Editing Input Image | Output Image |
|-------------------------|---------------------|--------------|
|  |  |  |
### Brightness Control + Image Editing + Local Redrawing

[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness) generates a bright scene, [DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit) references the original image layout, and [DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint) keeps the background unchanged, carrying the photographic scene over into an anime style.

```python
image = template(
    pipe,
    prompt="A cat is sitting on a stone. Flat anime style.",
    seed=0, cfg_scale=4, num_inference_steps=50,
    template_inputs=[
        {
            "model_id": 0,
            "scale": 0.6,
        },
        {
            "model_id": 2,
            "image": Image.open("data/assets/image_reference.jpg"),
            "prompt": "Convert the image style to flat anime style.",
        },
        {
            "model_id": 6,
            "image": Image.open("data/assets/image_reference.jpg"),
            "mask": Image.open("data/assets/image_mask_1.jpg"),
            "force_inpaint": True,
        },
    ],
    negative_template_inputs=[
        {
            "model_id": 0,
            "scale": 0.5,
        },
        {
            "model_id": 2,
            "image": Image.open("data/assets/image_reference.jpg"),
            "prompt": "",
        },
        {
            "model_id": 6,
            "image": Image.open("data/assets/image_reference.jpg"),
            "mask": Image.open("data/assets/image_mask_1.jpg"),
        },
    ],
)
image.save("image_Brightness_Edit_Inpaint.png")
```

| Reference Image | Redrawing Area | Output Image |
|------------------|----------------|--------------|
|  |  |  |
docs/en/Diffusion_Templates/Template_Model_Training.md
# Template Model Training

DiffSynth-Studio currently provides comprehensive Template training support for [black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B), with more model adaptations coming soon.

## Continuing Training from Pretrained Models

To continue training from our pretrained models, refer to the table in [FLUX.2](../Model_Details/FLUX2.md#model-overview) to find the corresponding training script.

## Building New Template Models

### Template Model Component Format

A Template model binds to a model repository (or local folder) containing a code file `model.py` as its entry point. Here is the template for `model.py`:
```python
import torch


class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def process_inputs(self, xxx, **kwargs):
        yyy = xxx
        return {"yyy": yyy}

    def forward(self, yyy, **kwargs):
        zzz = yyy
        return {"zzz": zzz}


class DataProcessor:
    def __call__(self, www, **kwargs):
        xxx = www
        return {"xxx": xxx}


TEMPLATE_MODEL = CustomizedTemplateModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataProcessor
```
During Template model inference, the Template Input passes through `TEMPLATE_MODEL`'s `process_inputs` and `forward` to produce the Template Cache.

```mermaid
flowchart LR;
i@{shape: text, label: "Template Input"}-->p[process_inputs];
subgraph TEMPLATE_MODEL
p[process_inputs]-->f[forward]
end
f[forward]-->c@{shape: text, label: "Template Cache"};
```

During Template model training, the Template Input comes from the dataset via `TEMPLATE_DATA_PROCESSOR`.

```mermaid
flowchart LR;
d@{shape: text, label: "Dataset"}-->dp[TEMPLATE_DATA_PROCESSOR]-->p[process_inputs];
subgraph TEMPLATE_MODEL
p[process_inputs]-->f[forward]
end
f[forward]-->c@{shape: text, label: "Template Cache"};
```
#### `TEMPLATE_MODEL`

`TEMPLATE_MODEL` implements the Template model logic. It inherits from `torch.nn.Module` and must provide `process_inputs` and `forward` methods. Together, these two methods form the complete Template model inference process; the process is split into two stages to better support [two-stage split training](https://diffsynth-studio-doc.readthedocs.io/en/latest/Training/Split_Training.html).

* `process_inputs` must use `@torch.no_grad()` for gradient-free computation
* `forward` must contain all gradient computations required for training

Both methods should accept `**kwargs` for compatibility. Reserved parameters include:

* To interact with the base model Pipeline (e.g., to call the text encoder), add a `pipe` parameter to the method inputs
* To enable gradient checkpointing, add `use_gradient_checkpointing` and `use_gradient_checkpointing_offload` to the `forward` inputs
* When multiple Template models are enabled, `model_id` is used to route Template Inputs; do not use this field as a method parameter
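Putting the reserved parameters together, a skeletal template model might look like this (plain Python for brevity; a real template model inherits `torch.nn.Module` and wraps `process_inputs` in `@torch.no_grad()`, and the `StubPipe` / `tokenize` names here are illustrative stand-ins, not DiffSynth APIs):

```python
# Sketch of the reserved-parameter conventions (plain Python for brevity;
# a real template model inherits torch.nn.Module and wraps process_inputs
# in @torch.no_grad()).

class SketchTemplateModel:
    def process_inputs(self, text, pipe=None, **kwargs):
        # `pipe` gives access to base-pipeline components such as the tokenizer.
        ids = pipe.tokenize(text) if pipe is not None else text
        return {"ids": ids}

    def forward(self, ids, use_gradient_checkpointing=False,
                use_gradient_checkpointing_offload=False, **kwargs):
        # The gradient-checkpointing flags are accepted but unused in this sketch.
        return {"cache": list(ids)}


class StubPipe:
    # Stand-in for the base model Pipeline used in this illustration.
    def tokenize(self, text):
        return [ord(c) for c in text]


model = SketchTemplateModel()
state = model.process_inputs("ab", pipe=StubPipe())
out = model.forward(**state)
print(out)  # {'cache': [97, 98]}
```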
#### `TEMPLATE_MODEL_PATH` (Optional)

`TEMPLATE_MODEL_PATH` specifies the relative path to the pretrained weights. For example:

```python
TEMPLATE_MODEL_PATH = "model.safetensors"
```

For multi-file models:

```python
TEMPLATE_MODEL_PATH = [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
]
```

Set it to `None` for random initialization:

```python
TEMPLATE_MODEL_PATH = None
```
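All three forms can be handled uniformly by normalizing to a list of files; this is a sketch of plausible loader behavior, not the actual `load_template_model` code:

```python
# Sketch of plausible loader behavior for TEMPLATE_MODEL_PATH
# (not the actual DiffSynth implementation).

def normalize_model_paths(template_model_path):
    if template_model_path is None:
        return []                      # random initialization, nothing to load
    if isinstance(template_model_path, str):
        return [template_model_path]   # single-file checkpoint
    return list(template_model_path)   # multi-file checkpoint

print(normalize_model_paths("model.safetensors"))
# ['model.safetensors']
```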
#### `TEMPLATE_DATA_PROCESSOR` (Optional)

To train Template models with DiffSynth-Studio, the dataset's `metadata.json` should contain `template_inputs` fields. These fields pass through `TEMPLATE_DATA_PROCESSOR` to generate the inputs for the Template model's methods.

For example, the brightness control model [DiffSynth-Studio/F2KB4B-Template-Brightness](https://modelscope.cn/models/DiffSynth-Studio/F2KB4B-Template-Brightness) takes `scale` as input:

```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"scale": 0.2}
    },
    {
        "image": "images/image_2.jpg",
        "prompt": "a dog",
        "template_inputs": {"scale": 0.6}
    }
]
```

```python
class DataProcessor:
    def __call__(self, scale, **kwargs):
        return {"scale": scale}


TEMPLATE_DATA_PROCESSOR = DataProcessor
```
Or compute the scale from image paths:

```json
[
    {
        "image": "images/image_1.jpg",
        "prompt": "a cat",
        "template_inputs": {"image": "/path/to/your/dataset/images/image_1.jpg"}
    }
]
```

```python
import numpy as np
from PIL import Image


class DataProcessor:
    def __call__(self, image, **kwargs):
        image = Image.open(image)
        image = np.array(image)
        return {"scale": image.astype(np.float32).mean() / 255}


TEMPLATE_DATA_PROCESSOR = DataProcessor
```
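The brightness label computed here is simply the normalized mean pixel value; the arithmetic can be checked without PIL or NumPy (a stdlib-only restatement of the processor above):

```python
# Stdlib-only restatement of the brightness label: the mean pixel
# value of an 8-bit grayscale image, normalized to [0, 1].

def brightness_scale(pixels):
    flat = [v for row in pixels for v in row]
    return sum(flat) / len(flat) / 255

image = [
    [0, 255],
    [51, 204],
]
print(brightness_scale(image))  # 0.5
```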
### Training Template Models

A Template model is trainable only if its Template Cache variables are fully decoupled from the base model Pipeline: these variables should reach `model_fn` without participating in any Pipeline Unit calculations.

For training with [black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B), use the following training script parameters:

* `--extra_inputs`: additional inputs. Use `template_inputs` for text-to-image models and `edit_image,template_inputs` for image editing models
* `--template_model_id_or_path`: Template model ID or local path (append a `:` suffix for ModelScope IDs, e.g., `"DiffSynth-Studio/Template-KleinBase4B-Brightness:"`)
* `--remove_prefix_in_ckpt`: state dict prefix to remove when saving models (use `"pipe.template_model."`)
* `--trainable_models`: trainable components (use `"template_model"` for the full model, or `"template_model.xxx,template_model.yyy"` for specific components)
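The `:` suffix convention above can be read as follows; this is a hypothetical helper written only to illustrate the convention, not the training script's actual argument handling:

```python
# Hypothetical illustration of the ":" suffix convention for
# --template_model_id_or_path; not the actual argument parser.

def resolve_template_source(value: str):
    if value.endswith(":"):
        # "org/model_id:" -> download from ModelScope
        return ("modelscope", value.rstrip(":"))
    # anything else is treated as a local path
    return ("local", value)

print(resolve_template_source("DiffSynth-Studio/Template-KleinBase4B-Brightness:"))
# ('modelscope', 'DiffSynth-Studio/Template-KleinBase4B-Brightness')
print(resolve_template_source("examples/flux2/model_training/scripts/brightness"))
# ('local', 'examples/flux2/model_training/scripts/brightness')
```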
Example training script:

```shell
accelerate launch examples/flux2/model_training/train.py \
  --dataset_base_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness \
  --dataset_metadata_path data/diffsynth_example_dataset/flux2/Template-KleinBase4B-Brightness/metadata.jsonl \
  --extra_inputs "template_inputs" \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
  --template_model_id_or_path "examples/flux2/model_training/scripts/brightness" \
  --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
  --learning_rate 1e-4 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.template_model." \
  --output_path "./models/train/Template-KleinBase4B-Brightness_example" \
  --trainable_models "template_model" \
  --use_gradient_checkpointing \
  --find_unused_parameters
```
### Interacting with Base Model Pipeline Components

Template models can interact with base model Pipeline components. For example, using the text encoder:

```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.xxx = xxx()

    @torch.no_grad()
    def process_inputs(self, text, pipe, **kwargs):
        input_ids = pipe.tokenizer(text)
        text_emb = pipe.text_encoder(input_ids)
        return {"text_emb": text_emb}

    def forward(self, text_emb, pipe, **kwargs):
        kv_cache = self.xxx(text_emb)
        return {"kv_cache": kv_cache}


TEMPLATE_MODEL = CustomizedTemplateModel
```
### Using Non-Trainable Components

For models that include pretrained components:

```python
class CustomizedTemplateModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = XXXEncoder.from_pretrained(xxx)
        self.mlp = MLP()

    @torch.no_grad()
    def process_inputs(self, image, **kwargs):
        emb = self.image_encoder(image)
        return {"emb": emb}

    def forward(self, emb, **kwargs):
        kv_cache = self.mlp(emb)
        return {"kv_cache": kv_cache}


TEMPLATE_MODEL = CustomizedTemplateModel
```

Set `--trainable_models template_model.mlp` to train only the MLP component.
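Selecting sub-components and stripping checkpoint prefixes can be pictured with a plain state dict; this is an illustrative sketch, not the trainer's actual code:

```python
# Illustrative sketch (not the actual trainer code): selecting trainable
# parameters by component prefix, and stripping the checkpoint prefix
# the way --remove_prefix_in_ckpt does.

state_dict = {
    "pipe.template_model.image_encoder.weight": 0,
    "pipe.template_model.mlp.weight": 1,
    "pipe.template_model.mlp.bias": 2,
}

def select_trainable(state, component):
    # e.g. component = "template_model.mlp"
    prefix = "pipe." + component + "."
    return {k: v for k, v in state.items() if k.startswith(prefix)}

def remove_prefix(state, prefix):
    # e.g. prefix = "pipe.template_model."
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

print(select_trainable(state_dict, "template_model.mlp"))
# {'pipe.template_model.mlp.weight': 1, 'pipe.template_model.mlp.bias': 2}
print(remove_prefix(state_dict, "pipe.template_model."))
# {'image_encoder.weight': 0, 'mlp.weight': 1, 'mlp.bias': 2}
```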
### Uploading Template Models

After training, follow these steps to upload the model to ModelScope:

1. Set the model path in `model.py`:

```python
TEMPLATE_MODEL_PATH = "model.safetensors"
```

2. Upload `model.py` using the ModelScope CLI:

```shell
modelscope upload user_name/your_model_id /path/to/your/model.py model.py --token ms-xxx
```

3. Package the model files:

```python
from diffsynth.diffusion.template import load_template_model, load_state_dict
from safetensors.torch import save_file
import torch

model = load_template_model("path/to/your/template/model", torch_dtype=torch.bfloat16, device="cpu")
state_dict = load_state_dict("path/to/your/ckpt/epoch-1.safetensors", torch_dtype=torch.bfloat16, device="cpu")
state_dict.update(model.state_dict())
save_file(state_dict, "model.safetensors")
```

4. Upload the model file:

```shell
modelscope upload user_name/your_model_id /path/to/your/model/epoch-1.safetensors model.safetensors --token ms-xxx
```

5. Verify inference:

```python
from diffsynth.diffusion.template import TemplatePipeline
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

# Load the base model
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
)

# Load the Template model
template_pipeline = TemplatePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="user_name/your_model_id")
    ],
)

# Generate an image
image = template_pipeline(
    pipe,
    prompt="a cat",
    seed=0, cfg_scale=4,
    height=1024, width=1024,
    template_inputs=[{xxx}],
)
image.save("image.png")
```
# Understanding Diffusion Templates

The Diffusion Templates framework is a controllable-generation plugin framework in DiffSynth-Studio that provides additional controllable generation capabilities for Diffusion models.

## Framework Structure

The structure of the Diffusion Templates framework is shown below:
```mermaid
flowchart TD;
subgraph Template Pipeline
si@{shape: text, label: "Template Input"}-->i1@{shape: text, label: "Template Input 1"};
si@{shape: text, label: "Template Input"}-->i2@{shape: text, label: "Template Input 2"};
si@{shape: text, label: "Template Input"}-->i3@{shape: text, label: "Template Input 3"};
i1@{shape: text, label: "Template Input 1"}-->m1[Template Model 1]-->c1@{shape: text, label: "Template Cache 1"};
i2@{shape: text, label: "Template Input 2"}-->m2[Template Model 2]-->c2@{shape: text, label: "Template Cache 2"};
i3@{shape: text, label: "Template Input 3"}-->m3[Template Model 3]-->c3@{shape: text, label: "Template Cache 3"};
c1-->c@{shape: text, label: "Template Cache"};
c2-->c;
c3-->c;
end
i@{shape: text, label: "Model Input"}-->m[Diffusion Pipeline]-->o@{shape: text, label: "Model Output"};
c-->m;
```
The framework consists of the following modules:

* **Template Input**: the Template model's input, a Python dictionary whose fields are determined by each Template model (e.g., `{"scale": 0.8}`)
* **Template Model**: the Template model itself, loadable from ModelScope (`ModelConfig(model_id="xxx/xxx")`) or from a local path (`ModelConfig(path="xxx")`)
* **Template Cache**: the Template model's output, a Python dictionary whose fields match base model Pipeline input parameters
* **Template Pipeline**: the module that manages multiple Template models, handling model loading and cache integration
When the Diffusion Templates framework is disabled, the base model components (Text Encoder, DiT, VAE) are loaded into the Diffusion Pipeline, and Model Input (prompt, height, width, etc.) produces Model Output (e.g., images).

When it is enabled, Template models are loaded into the Template Pipeline, which outputs a Template Cache, a subset of the Diffusion Pipeline's input parameters, for subsequent processing in the Diffusion Pipeline. Controllable generation is achieved by taking over part of the Diffusion Pipeline's input parameters.
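Because Template Cache fields are, by definition, Diffusion Pipeline input parameters, integration amounts to merging dictionaries into the pipeline call. This is a simplified stdlib sketch of that idea, not the framework's actual code:

```python
# Simplified sketch (not the framework's code): each template model emits
# a Template Cache whose keys are Diffusion Pipeline parameters; the
# Template Pipeline merges them into the final pipeline call.

def merge_caches(model_inputs, template_caches):
    merged = dict(model_inputs)
    for cache in template_caches:
        merged.update(cache)  # later caches override on key collisions
    return merged

call_kwargs = merge_caches(
    {"prompt": "a cat", "height": 1024, "width": 1024},
    [{"kv_cache": "<kv tensors>"}, {"lora": "<lora weights>"}],
)
print(sorted(call_kwargs))
# ['height', 'kv_cache', 'lora', 'prompt', 'width']
```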
## Model Capability Medium

Template Cache is defined as a subset of the Diffusion Pipeline's input parameters, which keeps the framework general: Template model outputs are restricted to Diffusion Pipeline parameters. Among these, KV-Cache is particularly suitable as a Diffusion medium:

* Proven effective in LLM skills, where prompts are converted to KV-Cache
* Holds "high privilege" in Diffusion models and can directly steer image generation
* Supports sequence-level concatenation across multiple Template models
* Requires minimal development (add a pipeline parameter and integrate it into the model)

Other potential Template mediums:

* **Residual**: used in ControlNet for point-to-point control, but limited by resolution and prone to conflicts when merging
* **LoRA**: treated as input parameters rather than model components

**Currently, we only support KV-Cache and LoRA as Template Cache mediums in the FLUX.2 Pipeline, with plans to support more models and mediums in the future.**
## Template Model Format

A Template model has this structure:

```
Template_Model
├── model.py
└── model.safetensors
```

Here `model.py` is the entry point and `model.safetensors` contains the model weights. For implementation details, see [Template Model Training](Template_Model_Training.md) or [existing Template models](https://modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness).
The following rows are added to the model-overview table (likely the table in [FLUX.2](../Model_Details/FLUX2.md#model-overview); column headers inferred from the linked script paths):

|Model|Inference|Low-VRAM Inference|Full Training|Validate (Full)|LoRA Training|Validate (LoRA)|
|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|[DiffSynth-Studio/Template-KleinBase4B-Aesthetic](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Aesthetic)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Aesthetic.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Aesthetic.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Aesthetic.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Aesthetic.py)|-|-|
|[DiffSynth-Studio/Template-KleinBase4B-Brightness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Brightness)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Brightness.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Brightness.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Brightness.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Brightness.py)|-|-|
|[DiffSynth-Studio/Template-KleinBase4B-ControlNet](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-ControlNet)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-ControlNet.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-ControlNet.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-ControlNet.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-ControlNet.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-Edit](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Edit.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Inpaint)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Inpaint.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Inpaint.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Inpaint.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-PandaMeme](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-PandaMeme)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-PandaMeme.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-PandaMeme.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-PandaMeme.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-PandaMeme.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-Sharpness](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Sharpness)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Sharpness.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Sharpness.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Sharpness.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Sharpness.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-SoftRGB](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-SoftRGB)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-SoftRGB.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-SoftRGB.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-SoftRGB.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-SoftRGB.py)|-|-|
|
||||
|[DiffSynth-Studio/Template-KleinBase4B-Upscaler](https://www.modelscope.cn/models/DiffSynth-Studio/Template-KleinBase4B-Upscaler)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference/Template-KleinBase4B-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_inference_low_vram/Template-KleinBase4B-Upscaler.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/full/Template-KleinBase4B-Upscaler.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/flux2/model_training/validate_full/Template-KleinBase4B-Upscaler.py)|-|-|
|
||||
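Each Template model in the table above plugs into its base-model pipeline: the base pipeline's call arguments are forwarded unchanged, and template-specific controls (such as the brightness `scale`) are supplied separately via `template_inputs`. The sketch below illustrates only this calling pattern; the mock classes are simplified stand-ins for illustration, not the real `TemplatePipeline` implementation from `diffsynth.diffusion.template` (see the linked inference scripts for the actual API).

```python
class MockBasePipeline:
    """Stand-in for a base pipeline such as Flux2ImagePipeline (illustrative only)."""
    def __call__(self, prompt, seed=0, cfg_scale=4, height=1024, width=1024, condition=None):
        # Pretend generation: return a dict describing the request instead of an image.
        return {"prompt": prompt, "condition": condition}

class MockTemplatePipeline:
    """Stand-in for a Template wrapper: forwards base-pipeline kwargs, adds template_inputs."""
    def __init__(self, base_pipeline):
        self.base_pipeline = base_pipeline

    def __call__(self, template_inputs=None, **pipe_kwargs):
        # The template turns its extra inputs (e.g. scale=0.8) into a
        # conditioning signal and forwards everything else unchanged.
        condition = (template_inputs or {}).get("scale", 1.0)
        return self.base_pipeline(condition=condition, **pipe_kwargs)

template_pipeline = MockTemplatePipeline(MockBasePipeline())
out = template_pipeline(prompt="a cat", seed=0, template_inputs={"scale": 0.8})
print(out)  # {'prompt': 'a cat', 'condition': 0.8}
```

The key point is that `pipe` and `template_pipeline` accept the same generation arguments, so existing base-model code only needs the extra `template_inputs` dictionary.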

Special Training Scripts:
@@ -18,6 +18,9 @@ graph LR;
     I_want_to_explore_new_technologies_based_on_this_project-->sec5[Section 5: API Reference];
-    I_want_to_explore_new_technologies_based_on_this_project-->sec6[Section 6: Academic Guide];
-    I_encountered_a_problem-->sec7[Section 7: Frequently Asked Questions];
+    I_want_to_explore_new_technologies_based_on_this_project-->sec6[Section 6: Diffusion Templates];
+    I_want_to_explore_new_technologies_based_on_this_project-->sec7[Section 7: Academic Guide];
+    I_encountered_a_problem-->sec8[Section 8: Frequently Asked Questions];
 ```
</details>
@@ -75,7 +78,15 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
 * [`diffsynth.core.loader`](./API_Reference/core/loader.md): Model download and loading
 * [`diffsynth.core.vram`](./API_Reference/core/vram.md): VRAM management
 
-## Section 6: Academic Guide
+## Section 6: Diffusion Templates
+
+This section introduces the controllable generation plugin framework for Diffusion models, explaining the framework's operation mechanism and how to use Template models for inference and training.
+
+* [Understanding Diffusion Templates](./Diffusion_Templates/Understanding_Diffusion_Templates.md)
+* [Template Model Inference](./Diffusion_Templates/Template_Model_Inference.md)
+* [Template Model Training](./Diffusion_Templates/Template_Model_Training.md)
+
+## Section 7: Academic Guide
 
 This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
 
@@ -84,8 +95,8 @@ This section introduces how to use `DiffSynth-Studio` to train new models, helpi
 * Designing controllable generation models 【coming soon】
 * Creating new training paradigms 【coming soon】
 
-## Section 7: Frequently Asked Questions
+## Section 8: Frequently Asked Questions
 
 This section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub.
 
 * [Frequently Asked Questions](./QA.md)
@@ -60,6 +60,14 @@ Welcome to DiffSynth-Studio's Documentation
    API_Reference/core/loader
    API_Reference/core/vram
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Diffusion Templates
+
+   Diffusion_Templates/Understanding_Diffusion_Templates.md
+   Diffusion_Templates/Template_Model_Inference.md
+   Diffusion_Templates/Template_Model_Training.md
+
 .. toctree::
    :maxdepth: 2
    :caption: Research Guide
 