DiffSynth-Studio/docs/zh/Research_Tutorial/train_from_scratch.md
2026-02-10 20:59:47 +08:00


Training a Model from Scratch

The DiffSynth-Studio training engine supports training base models from scratch. This article walks through training a small text-to-image model with only 0.1B parameters.

1. Building the Model Architecture

1.1 The Diffusion Model

From UNet [1] [2] to DiT [3] [4], the mainstream Diffusion model architecture has gone through several evolutions. Typically, a Diffusion model takes the following inputs:

  • Image tensor (latents): the encoded image, produced by the VAE model, containing some amount of noise
  • Text tensor (prompt_embeds): the encoded text, produced by the text encoder
  • Timestep (timestep): a scalar marking which stage of the Diffusion process we are currently in

The model's output is a tensor with the same shape as the image tensor, representing the predicted denoising direction. For the theory behind Diffusion models, see Fundamentals of Diffusion Models. In this article, we build a DiT model with only 0.1B parameters: AAADiT.

Model architecture code
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat

from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb


class AAABlock(torch.nn.Module):
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim*3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb
    
    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb
    
    def forward(self, emb, pos_emb, t_emb):
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block_id, block in enumerate(self.blocks):
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb
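The 0.1B figure above can be sanity-checked with a little arithmetic. The sketch below counts the parameters of AAADiT directly from the definitions in the code (the RMSNorm layers use elementwise_affine=False and so contribute nothing; TimestepEmbeddings is omitted because its internals live in diffsynth.models.general_modules):

```python
def linear_params(d_in, d_out):
    return d_in * d_out + d_out  # weight matrix + bias vector

dim = 1024
# One AAABlock: q/k/v/out projections, a 3x-expansion MLP, and the timestep gate
attn = 4 * linear_params(dim, dim)
mlp = linear_params(dim, 3 * dim) + linear_params(3 * dim, dim)
gate = linear_params(dim, 2 * dim)
blocks = 10 * (attn + mlp + gate)

# Embedders and output projection (each LayerNorm adds 2*dim affine parameters)
pos = dim * 16 * 16 + dim                      # AAAPositionalEmbedding
image_in = linear_params(128, dim) + 2 * dim   # image_embedder
text_in = linear_params(1024, dim) + 2 * dim   # text_embedder
proj_out = linear_params(dim, 128)

total = blocks + pos + image_in + text_in + proj_out  # ≈ 0.13B
print(f"{total / 1e9:.3f}B parameters (excluding TimestepEmbeddings)")
```

The ten transformer blocks account for almost all of the parameters; the embedders and output projection add under 2M on top.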

1.2 Encoder and Decoder Models

Besides the Diffusion model that performs denoising, we need two more models:

  • Text encoder: encodes text into tensors. We use the Qwen/Qwen3-0.6B model.
  • VAE: its encoder encodes images into tensors, and its decoder decodes image tensors back into images. We use the VAE from black-forest-labs/FLUX.2-klein-4B.

Both model architectures are already integrated into DiffSynth-Studio, at /diffsynth/models/z_image_text_encoder.py and /diffsynth/models/flux2_vae.py respectively, so no code changes are needed.
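The shapes these models exchange follow a simple rule. Assuming the FLUX.2 VAE compresses each spatial dimension by a factor of 16 into 128 latent channels (the factors the pipeline code in the next section uses as height//16, width//16, matching the 128-channel in/out projections of AAADiT), the latent shape can be sketched as:

```python
def latent_shape(height, width, channels=128, factor=16):
    # The VAE encoder maps a (1, 3, H, W) image to a
    # (1, channels, H // factor, W // factor) latent tensor;
    # the decoder inverts the mapping.
    assert height % factor == 0 and width % factor == 0, "dims must divide by 16"
    return (1, channels, height // factor, width // factor)

print(latent_shape(256, 256))    # the training resolution used later
print(latent_shape(1024, 1024))  # the pipeline's default resolution
```

This is also why the pipeline sets height_division_factor and width_division_factor to 16: image sizes must be divisible by the VAE's compression factor.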

2. Building the Pipeline

The document Integrating a Pipeline describes how to build a model Pipeline. For the model in this article, we likewise build a Pipeline that connects the text encoder, the Diffusion model, and the VAE.

Pipeline code
class AAAImagePipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoProcessor = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa
    
    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        
        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe
    
    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
        
        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output
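The cfg_guided_model_fn called in the denoising loop is provided by BasePipeline. Conceptually, classifier-free guidance combines the conditional and unconditional predictions as sketched below (an illustration of the standard CFG formula, not the library's actual implementation):

```python
import torch

def cfg_combine(pred_posi, pred_nega, cfg_scale):
    # Extrapolate from the unconditional (negative-prompt) prediction toward
    # the conditional one. cfg_scale == 1 reduces to the purely conditional
    # prediction, so the negative branch can be skipped entirely in that case.
    return pred_nega + cfg_scale * (pred_posi - pred_nega)
```

A cfg_scale above 1 pushes the output further in the direction suggested by the prompt, which is why the inference examples at the end of this article pass cfg_scale=10.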

3. Preparing the Dataset

To validate training quickly, we use the dataset Pokémon: Generation 1, adapted from the open-source project pokemon-dataset-zh; it contains the 151 first-generation Pokémon, from Bulbasaur to Mew. If you want to use another dataset, see the documents Preparing a Dataset and diffsynth.core.data.

modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
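The downloaded dataset already ships its own metadata_merged.csv, so nothing needs to be written by hand. The sketch below, with hypothetical file names, only illustrates the layout the training code reads, assuming an image column of paths relative to base_path and a prompt column of captions:

```python
import csv

# Hypothetical rows illustrating the assumed image/prompt schema.
rows = [
    {"image": "0001.jpg", "prompt": "green, lizard, plant, Grass, Poison, seed on back"},
    {"image": "0007.jpg", "prompt": "blue, turtle, Water, shell, large eyes, curled tail"},
]
with open("metadata_example.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```

Each row pairs one training image with its caption; UnifiedDataset resolves the image paths against base_path and exposes each row as a dict to the training module.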

4. Starting Training

Training can be implemented quickly on top of the Pipeline. The complete code is in ../Research_Tutorial/train_from_scratch.py; start single-GPU training directly with python docs/zh/Research_Tutorial/train_from_scratch.py.

For multi-GPU parallel training, run accelerate config to set the relevant parameters, then launch with accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py.

The training script has no stopping condition, so terminate it manually when appropriate. The model converges after roughly 60,000 steps, which takes about 10 to 20 hours on a single GPU.

Training code
class AAATrainingModule(DiffusionTrainingModule):
    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
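FlowMatchSFTLoss is implemented in diffsynth.diffusion. For intuition, the rectified-flow objective it follows can be sketched as: the noisy sample linearly interpolates between image latents and noise, and the model regresses the constant velocity pointing from image to noise (a conceptual sketch, not the library code):

```python
import torch

def flow_match_sketch(model_pred, x0, noise, t):
    # t in [0, 1]: t=0 is the clean latent image, t=1 is pure noise.
    x_t = (1 - t) * x0 + t * noise   # noisy latents the model would see
    target = noise - x0              # velocity target d(x_t)/dt
    loss = torch.nn.functional.mse_loss(model_pred, target)
    return x_t, loss
```

During sampling, the scheduler integrates this learned velocity field backward from t=1 (noise) to t=0 (image), which is what the step calls in the pipeline's denoising loop perform.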

5. Validating the Training Results

If you don't want to wait for training to finish, you can download our pre-trained model directly:

modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel

Load the model:

from diffsynth import load_model

pipe = AAAImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")

Run inference to generate the three Generation 1 starter Pokémon; at this point, the generated images closely match the training data.

for seed, prompt in enumerate([
    "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
    "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
    "蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",  # Chinese prompt: blue, cream, brown, turtle, Water type, shell, large eyes, short limbs, curled tail
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed,
        height=256, width=256,
    )
    image.save(f"image_{seed}.jpg")

Run inference to generate Pokémon with "sharp claws"; at this point, different random seeds produce different images.

for seed, prompt in enumerate([
    "sharp claws",
    "sharp claws",
    "sharp claws",
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed+4,
        height=256, width=256,
    )
    image.save(f"image_sharp_claws_{seed}.jpg")

We now have a small 0.1B text-to-image model. It can generate the 151 Pokémon, but nothing beyond them. Scale up the data, the model parameters, and the number of GPUs on this foundation, and you can train a far more capable text-to-image model!