Merge branch 'main' into ltx-2

This commit is contained in:
Zhongjie Duan
2026-02-03 13:06:44 +08:00
committed by GitHub
34 changed files with 2132 additions and 37 deletions

View File

@@ -85,6 +85,7 @@ graph LR;
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |


@@ -50,9 +50,14 @@ image.save("image.jpg")
## Model Overview
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
| - | - | - | - | - | - | - |
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
Special Training Scripts:
@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:
* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Device on which the random Gaussian noise tensor is generated, default is `"cpu"`. When set to `"cuda"`, different GPUs may produce different generation results.
* `num_inference_steps`: Number of inference steps, default value is 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Reference image(s) for image-editing models; multiple images are supported.
* `positive_only_lora`: LoRA weights applied only on the positive-prompt side.
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.


@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
* Inference improvement techniques 【coming soon】
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】


@@ -0,0 +1,476 @@
# Training Models from Scratch
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
## 1. Building Model Architecture
### 1.1 Diffusion Model
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream Diffusion model architectures have gone through several generations of evolution. Typically, a Diffusion model's inputs include:
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
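To make those inputs and the output concrete, here is a minimal shape sketch (a toy stand-in, not the real model; the tensor sizes are illustrative assumptions):

```python
import torch

# Illustrative shapes only: 128 latent channels on a 16x16 grid,
# and text embeddings of width 1024 (both are assumptions here).
latents = torch.randn(1, 128, 16, 16)     # noisy VAE encoding of the image
prompt_embeds = torch.randn(1, 77, 1024)  # text-encoder output
timestep = torch.tensor([500.0])          # scalar marking the diffusion stage

class ToyDenoiser(torch.nn.Module):
    def forward(self, latents, prompt_embeds, timestep):
        # A real DiT mixes all three inputs; this stub only
        # demonstrates that the output matches the latents' shape.
        return torch.zeros_like(latents)

noise_pred = ToyDenoiser()(latents, prompt_embeds, timestep)
```

The output tensor, the predicted denoising direction, has exactly the same shape as `latents`.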
<details>
<summary>Model Architecture Code</summary>
```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
```
</details>
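The ~0.1B parameter count can be sanity-checked with quick arithmetic over the `Linear` layers defined in the blocks above (this counts only the 10 `AAABlock`s; the embedders, norms, and `proj_out` add a small amount on top):

```python
dim = 1024

def linear_params(d_in, d_out):
    # weight matrix plus bias vector of a torch.nn.Linear
    return d_in * d_out + d_out

attn = 4 * linear_params(dim, dim)                       # to_q, to_k, to_v, to_out
ff = linear_params(dim, 3 * dim) + linear_params(3 * dim, dim)
gate = linear_params(dim, 2 * dim)                       # to_gate
per_block = attn + ff + gate
total_blocks = 10 * per_block
print(f"{total_blocks / 1e9:.3f}B parameters in the blocks")  # roughly 0.126B
```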
### 1.2 Encoder-Decoder Models
Besides the Diffusion model used for denoising, we also need two other models:
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
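With the VAE's 16x spatial downsampling and the prompt padded to 128 tokens (the `max_length` used by the tokenizer call in the Pipeline code), the DiT's sequence length is easy to work out. The helper below is a hypothetical illustration of that bookkeeping:

```python
def dit_sequence_length(height, width, text_tokens=128, patch=16):
    # Each 16x16 pixel patch becomes one latent position, and the text
    # tokens are concatenated after the image tokens.
    image_tokens = (height // patch) * (width // patch)
    return image_tokens + text_tokens

seq_256 = dit_sequence_length(256, 256)      # 16*16 image tokens + 128 text
seq_1024 = dit_sequence_length(1024, 1024)   # 64*64 image tokens + 128 text
```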
## 2. Building Pipeline
We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
<details>
<summary>Pipeline Code</summary>
```python
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
```
</details>
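Inside the denoising loop, `cfg_guided_model_fn` runs the model under the positive and negative prompts and combines the two predictions. Assuming the standard classifier-free guidance rule and a flow-match Euler update (a conceptual sketch, not the library's exact implementation):

```python
import torch

def cfg_combine(pred_posi, pred_nega, cfg_scale):
    # Standard classifier-free guidance: push the prediction away from
    # the negative-prompt direction by cfg_scale.
    return pred_nega + cfg_scale * (pred_posi - pred_nega)

def euler_step(latents, noise_pred, sigma, sigma_next):
    # Flow-match Euler update: move along the predicted velocity.
    return latents + (sigma_next - sigma) * noise_pred

pred_posi = torch.ones(1, 4, 4, 4)
pred_nega = torch.zeros(1, 4, 4, 4)
guided = cfg_combine(pred_posi, pred_nega, cfg_scale=10.0)
latents = euler_step(torch.zeros(1, 4, 4, 4), guided, sigma=1.0, sigma_next=0.9)
```

With `cfg_scale=1.0` (the pipeline's default), `guided` reduces to the positive prediction alone, which is why the training script can skip the negative branch.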
## 3. Preparing Dataset
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing the 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```
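If you substitute your own data, the training script's `UnifiedDataset` reads a metadata table with an image-path column and a prompt column. A minimal sketch of such a file (the `image`/`prompt` column names are assumptions for illustration; see Preparing Datasets for the actual schema):

```python
import csv, os

# Hypothetical minimal metadata file for UnifiedDataset-style training data.
os.makedirs("data_demo", exist_ok=True)
rows = [
    {"image": "images/0001.jpg", "prompt": "green, lizard, plant, seed on back"},
    {"image": "images/0004.jpg", "prompt": "orange, lizard, flame on tail tip"},
]
with open("data_demo/metadata.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)

# Read it back to confirm the layout round-trips.
with open("data_demo/metadata.csv") as f:
    loaded = list(csv.DictReader(f))
```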
## 4. Starting Training
The training process can be implemented quickly on top of the Pipeline. The complete code is available at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py); start single-GPU training directly with `python docs/en/Research_Tutorial/train_from_scratch.py`.
For multi-GPU parallel training, run `accelerate config` to set the relevant parameters, then start training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`.
This training script has no stopping condition; stop it manually when needed. The model converges after roughly 60,000 training steps, which takes 10 to 20 hours on a single GPU.
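Conceptually, `FlowMatchSFTLoss` regresses the model's output onto the flow-matching velocity target. A minimal sketch of that objective, assuming the standard rectified-flow formulation (an illustration, not DiffSynth-Studio's exact code):

```python
import torch

def flow_match_sft_loss(model, latents_clean, noise, t):
    # Rectified flow: interpolate between clean latents and noise ...
    x_t = (1 - t) * latents_clean + t * noise
    # ... and regress the velocity that points from data to noise.
    target = noise - latents_clean
    pred = model(x_t, t)
    return torch.nn.functional.mse_loss(pred, target)

# Tiny check: a "perfect" model that returns the true velocity gets zero loss.
latents_clean = torch.zeros(1, 4, 8, 8)
noise = torch.ones(1, 4, 8, 8)
perfect_model = lambda x_t, t: noise - latents_clean
loss = flow_match_sft_loss(perfect_model, latents_clean, noise, t=0.5)
```

Training drives the DiT toward this "perfect" predictor on the training distribution.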
<details>
<summary>Training Code</summary>
```python
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)
```
</details>
## 5. Verifying Training Results
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```
Load the trained model:
```python
from diffsynth import load_model
pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```
Run inference to generate the first-generation Pokemon "starter trio". The images the model produces now closely match the training data.
```python
for seed, prompt in enumerate([
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed,
height=256, width=256,
)
image.save(f"image_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|
|-|-|-|
Run inference again to generate Pokemon with "sharp claws". Different random seeds now produce different images.
```python
for seed, prompt in enumerate([
"sharp claws",
"sharp claws",
"sharp claws",
]):
image = pipe(
prompt=prompt,
negative_prompt=" ",
num_inference_steps=30,
cfg_scale=10,
seed=seed+4,
height=256, width=256,
)
image.save(f"image_sharp_claws_{seed}.jpg")
```
|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|
|-|-|-|
We now have a small 0.1B text-to-image model. It can already generate the 151 Pokemon, but nothing beyond them; by scaling up the data, the model size, and the number of GPUs from this starting point, you can train a far more capable text-to-image model!


@@ -0,0 +1,341 @@
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE
class AAAPositionalEmbedding(torch.nn.Module):
def __init__(self, height=16, width=16, dim=1024):
super().__init__()
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
def forward(self, image, text):
height, width = image.shape[-2:]
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
emb = torch.concat([image_emb, text_emb], dim=1)
return emb
class AAABlock(torch.nn.Module):
def __init__(self, dim=1024, num_heads=32):
super().__init__()
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.to_q = torch.nn.Linear(dim, dim)
self.to_k = torch.nn.Linear(dim, dim)
self.to_v = torch.nn.Linear(dim, dim)
self.to_out = torch.nn.Linear(dim, dim)
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*3),
torch.nn.SiLU(),
torch.nn.Linear(dim*3, dim),
)
self.to_gate = torch.nn.Linear(dim, dim * 2)
self.num_heads = num_heads
def attention(self, emb, pos_emb):
emb = self.norm_attn(emb + pos_emb)
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
emb = attention_forward(
q, k, v,
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
dims={"n": self.num_heads},
)
emb = self.to_out(emb)
return emb
def feed_forward(self, emb, pos_emb):
emb = self.norm_mlp(emb + pos_emb)
emb = self.ff(emb)
return emb
def forward(self, emb, pos_emb, t_emb):
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
return emb
class AAADiT(torch.nn.Module):
def __init__(self, dim=1024):
super().__init__()
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
self.timestep_embedder = TimestepEmbeddings(256, dim)
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
self.proj_out = torch.nn.Linear(dim, 128)
def forward(
self,
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
pos_emb = self.pos_embedder(latents, prompt_embeds)
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
text = self.text_embedder(prompt_embeds)
emb = torch.concat([image, text], dim=1)
for block_id, block in enumerate(self.blocks):
emb = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
emb=emb,
pos_emb=pos_emb,
t_emb=t_emb,
)
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
emb = self.proj_out(emb)
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
return emb
class AAAImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: ZImageTextEncoder = None
self.dit: AAADiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoProcessor = None
self.in_iteration_models = ("dit",)
self.units = [
AAAUnit_PromptEmbedder(),
AAAUnit_NoiseInitializer(),
AAAUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_aaa
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
# Initialize pipeline
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
pipe.dit = model_pool.fetch_model("aaa_dit")
pipe.vae = model_pool.fetch_model("flux2_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 30,
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"input_image": input_image, "denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
image = self.vae.decode(inputs_shared["latents"])
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class AAAUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
self.hidden_states_layers = (-1,)
def process(self, pipe: AAAImagePipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
text = pipe.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
return {"prompt_embeds": prompt_embeds}
class AAAUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class AAAUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: AAAImagePipeline, input_image, noise):
if input_image is None:
return {"latents": noise, "input_latents": None}
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(image)
if pipe.scheduler.training:
return {"latents": noise, "input_latents": input_latents}
else:
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"latents": latents, "input_latents": input_latents}
def model_fn_aaa(
dit: AAADiT,
latents=None,
prompt_embeds=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
model_output = dit(
latents,
prompt_embeds,
timestep,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
return model_output
class AAATrainingModule(DiffusionTrainingModule):
def __init__(self, device):
super().__init__()
self.pipe = AAAImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device=device,
model_configs=[
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
self.pipe.freeze_except(["dit"])
self.pipe.scheduler.set_timesteps(1000, training=True)
def forward(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
"cfg_scale": 1,
"use_gradient_checkpointing": False,
"use_gradient_checkpointing_offload": False,
}
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
return loss
if __name__ == "__main__":
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
dataset = UnifiedDataset(
base_path="data/images",
metadata_path="data/metadata_merged.csv",
max_data_items=10000000,
data_file_keys=("image",),
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
)
model = AAATrainingModule(device=accelerator.device)
model_logger = ModelLogger(
"models/AAA/v1",
remove_prefix_in_ckpt="pipe.dit.",
)
launch_training_task(
accelerator, dataset, model, model_logger,
learning_rate=2e-4,
num_workers=4,
save_steps=50000,
num_epochs=999999,
)


@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
(Figure)
![Image](https://github.com/user-attachments/assets/6471ae4c-a635-4924-8b36-b0bd4d42043d)
This process is intuitive, but to understand the details, we need to answer several questions:
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
(Figure)
![Image](https://github.com/user-attachments/assets/e25a2f71-123c-4e18-8b34-3a066af15667)
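The formula above can be sketched in a few lines of plain Python (`add_noise` is an illustrative name, not part of any library; lists stand in for tensors):

```python
def add_noise(x0, xT, sigma_t):
    """Synthesize the noisy sample x_t = (1 - sigma_t) * x0 + sigma_t * xT."""
    return [(1 - sigma_t) * a + sigma_t * b for a, b in zip(x0, xT)]

clean = [0.0, 1.0]   # stand-in for a clean data sample x_0
noise = [1.0, -1.0]  # stand-in for Gaussian noise x_T
x_half = add_noise(clean, noise, 0.5)  # halfway along the diffusion path
```

Note that $\sigma_t=0$ recovers the clean sample and $\sigma_t=1$ recovers pure noise, matching the endpoints of the schedule.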
## How is the iterative denoising computation performed?
@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
(Figure)
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
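Since $x_t=(1-\sigma_t)x_0+\sigma_t x_T$ and the model output approximates $x_T-x_0$, moving from noise level $\sigma_t$ to $\sigma_{t-1}$ can be sketched as a first-order (Euler) update. A minimal illustration (`euler_step` is a hypothetical helper name):

```python
def euler_step(x_t, eps_hat, sigma_t, sigma_prev):
    """Move x_t from noise level sigma_t to sigma_prev along the predicted
    direction eps_hat ~ x_T - x_0 (a first-order / Euler update)."""
    return [x + (sigma_prev - sigma_t) * e for x, e in zip(x_t, eps_hat)]
```

With a perfect prediction, stepping from $\sigma_t$ all the way to $0$ recovers $x_0$ exactly, which is why a single large step is possible in principle but inaccurate in practice when the prediction is imperfect.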
@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra
The training process differs from the generation process. If we retained the multi-step iterations during training, gradients would have to backpropagate through every step, incurring prohibitive time and memory costs. To improve computational efficiency, we randomly select a single time step $t$ for training.
(Figure)
The following is pseudocode for the training process:
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
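The pseudocode above can be sketched as a single training step in plain Python (`training_step` and `predict` are illustrative names; lists stand in for tensors, and mean-squared error is assumed as the loss):

```python
import random

def training_step(x0, xT, predict):
    """One training step (sketch): sample a random noise level, synthesize
    x_t, and regress the model output toward the target xT - x0 (MSE)."""
    sigma_t = random.random()                                 # random time step
    x_t = [(1 - sigma_t) * a + sigma_t * b for a, b in zip(x0, xT)]
    target = [b - a for a, b in zip(x0, xT)]                  # direction x_T - x_0
    pred = predict(x_t, sigma_t)
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
```

Because only one randomly chosen step is computed per sample, no gradient needs to flow through the full denoising chain.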
@@ -113,7 +109,7 @@ The following is pseudocode for the training process:
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured: mainstream designs follow the "three-stage" architecture proposed by Latent Diffusion, consisting of a data encoder-decoder, a guidance condition encoder, and a denoising model.
(Figure)
![Image](https://github.com/user-attachments/assets/43855430-6427-4aca-83a0-f684e01438b1)
### Data Encoder-Decoder