# Training Models from Scratch

DiffSynth-Studio's training engine supports training foundation models from scratch. This article shows how to train a small text-to-image model with only 0.1B parameters from scratch.

## 1. Building Model Architecture

### 1.1 Diffusion Model

From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream Diffusion model architectures have gone through several generations. Typically, a Diffusion model takes the following inputs:

* Image tensor (`latents`): the encoding of an image produced by the VAE, with partial noise added
* Text tensor (`prompt_embeds`): the encoding of the text produced by the text encoder
* Timestep (`timestep`): a scalar indicating the current stage of the Diffusion process

The model outputs a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details on Diffusion model theory, see [Basic Principles of Diffusion Models](../Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
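The input/output contract above can be made concrete by tracking only tensor shapes. The constants in the sketch below (128 latent channels, 16× spatial downsampling, 128 text tokens of dimension 1024) are assumptions read off the `AAADiT` and pipeline code in this article, and `dit_shapes` is a hypothetical helper, not part of DiffSynth-Studio.

```python
# Shape bookkeeping for the DiT in this article (plain Python, no tensors).
# Assumed constants, taken from the AAADiT / pipeline code in this article:
LATENT_CHANNELS = 128  # channels of the VAE latent consumed by AAADiT
DOWNSAMPLE = 16        # spatial downsampling from image to latent
TEXT_TOKENS = 128      # max_length used when tokenizing the prompt
TEXT_DIM = 1024        # hidden size of the text encoder output


def dit_shapes(height: int, width: int, batch: int = 1) -> dict:
    """Return the shapes of the DiT inputs and output for one image size."""
    if height % DOWNSAMPLE or width % DOWNSAMPLE:
        raise ValueError("height and width must be multiples of 16")
    h, w = height // DOWNSAMPLE, width // DOWNSAMPLE
    latents = (batch, LATENT_CHANNELS, h, w)
    return {
        "latents": latents,
        "prompt_embeds": (batch, TEXT_TOKENS, TEXT_DIM),
        "timestep": (batch,),
        # The predicted denoising direction has the same shape as the latents.
        "output": latents,
        # Inside the DiT, image tokens and text tokens form one sequence.
        "sequence_length": h * w + TEXT_TOKENS,
    }


shapes = dit_shapes(256, 256)
print(shapes["latents"], shapes["output"], shapes["sequence_length"])
# (1, 128, 16, 16) (1, 128, 16, 16) 384
```

For a 256×256 training image, the DiT therefore attends over 16 × 16 = 256 image tokens plus 128 text tokens.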
Model Architecture Code

```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    # Learnable positional embeddings: a 2D grid for image tokens (resized to the
    # latent resolution) and a single shared vector for all text tokens.
    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb


class AAABlock(torch.nn.Module):
    # Transformer block: self-attention and MLP, each scaled by a timestep-dependent gate.
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim*3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb

    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb

    def forward(self, emb, pos_emb, t_emb):
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        # One timestep embedding broadcast over all tokens (assumes batch size 1).
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        # Image tokens first, then text tokens, in a single sequence.
        emb = torch.concat([image, text], dim=1)
        for block_id, block in enumerate(self.blocks):
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        # Keep only the image tokens and project them back to the latent space.
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb
```
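As a sanity check on the "0.1B" figure, the sketch below counts the parameters of `AAADiT` by hand, following the layer definitions above. The helper names are hypothetical, and the `TimestepEmbeddings` module is left out because its internals are defined inside DiffSynth-Studio rather than in this article.

```python
# Hand count of AAADiT's parameters, mirroring the layer definitions above.
# Hypothetical helpers; TimestepEmbeddings is excluded from the total.

def linear(n_in: int, n_out: int) -> int:
    """Parameters of torch.nn.Linear(n_in, n_out) with bias."""
    return n_in * n_out + n_out


def layer_norm(dim: int) -> int:
    """Parameters of torch.nn.LayerNorm(dim): weight and bias."""
    return 2 * dim


def block_params(dim: int) -> int:
    attn = 4 * linear(dim, dim)                        # to_q, to_k, to_v, to_out
    mlp = linear(dim, dim * 3) + linear(dim * 3, dim)  # ff
    gate = linear(dim, dim * 2)                        # to_gate
    # RMSNorm(elementwise_affine=False) has no parameters.
    return attn + mlp + gate


def dit_params(dim=1024, num_blocks=10, text_dim=1024, latent_channels=128) -> int:
    pos = dim * 16 * 16 + dim                          # image_emb grid + text_emb
    embed = (linear(latent_channels, dim) + layer_norm(dim)
             + linear(text_dim, dim) + layer_norm(dim))
    out = linear(dim, latent_channels)                 # proj_out
    return pos + embed + num_blocks * block_params(dim) + out


print(block_params(1024))  # 12593152
print(dit_params())        # 127511680
```

Almost all of the roughly 0.13B parameters sit in the ten blocks, at about 12·dim² each. With PyTorch and DiffSynth-Studio installed, `sum(p.numel() for p in AAADiT().parameters())` gives the exact total, including the timestep embedder.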
### 1.2 Encoder-Decoder Models

Besides the Diffusion model used for denoising, we need two other models:

* Text Encoder: encodes text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE Encoder-Decoder: the encoder encodes images into tensors, and the decoder decodes image tensors back into images. We adopt the VAE from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).

The architectures of both models are already integrated into DiffSynth-Studio, at [/diffsynth/models/z_image_text_encoder.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/models/flux2_vae.py), so no code changes are needed.

## 2. Building Pipeline

The document [Integrating Pipeline](../Developer_Guide/Building_a_Pipeline.md) describes how to build a model Pipeline. The model in this article also needs a Pipeline that connects the text encoder, the Diffusion model, and the VAE encoder-decoder.
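A Pipeline is assembled from units that each read some named values and produce new ones. The toy sketch below illustrates that data flow in plain Python under our own simplifying assumptions. It is not DiffSynth-Studio's `PipelineUnit` or `unit_runner` implementation: `ToyUnit` and `run_units` are hypothetical, the process functions return stand-in values instead of tensors, and the separate positive/negative prompt handling of the real units is omitted.

```python
# Toy illustration of the "pipeline unit" pattern (hypothetical, not DiffSynth's code).
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyUnit:
    input_params: tuple            # names read from the shared state
    output_params: tuple           # names written back to the shared state
    process: Callable[..., dict]   # computes outputs from the inputs


def run_units(units, inputs: dict) -> dict:
    state = dict(inputs)
    for unit in units:
        kwargs = {name: state[name] for name in unit.input_params}
        outputs = unit.process(**kwargs)
        # Only the declared outputs are written back.
        for name in unit.output_params:
            state[name] = outputs[name]
    return state


units = [
    # Stand-in "text encoder": one number per word.
    ToyUnit(("prompt",), ("prompt_embeds",),
            lambda prompt: {"prompt_embeds": [len(w) for w in prompt.split()]}),
    # Stand-in noise: just the shape the real noise tensor would have.
    ToyUnit(("height", "width"), ("noise",),
            lambda height, width: {"noise": (1, 128, height // 16, width // 16)}),
    # Without an input image, denoising starts from pure noise.
    ToyUnit(("input_image", "noise"), ("latents",),
            lambda input_image, noise: {"latents": noise if input_image is None else input_image}),
]

state = run_units(units, {"prompt": "sharp claws", "height": 256, "width": 256, "input_image": None})
print(state["prompt_embeds"], state["latents"])  # [5, 5] (1, 128, 16, 16)
```

Because each unit declares its inputs and outputs by name, the same unit list can serve both inference and training, which is how the training module later reuses `pipe.units`.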
Pipeline Code

```python
class AAAImagePipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoProcessor = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])
        return image


class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output
```
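The denoising loop in `__call__` combines a guided model prediction with a scheduler step. The sketch below reproduces that control flow on plain Python lists. Two parts are assumptions rather than DiffSynth-Studio's source: the guidance rule `nega + cfg_scale * (posi - nega)` (standard classifier-free guidance) and the Euler update `latents + (sigma_next - sigma) * prediction` for the flow-matching step. `toy_model` is a hypothetical oracle that already knows the clean sample.

```python
# Sketch of a guided flow-matching denoising loop (assumed formulas, toy model).

def cfg_combine(posi, nega, cfg_scale):
    # Classifier-free guidance: extrapolate from the unconditional prediction
    # towards (and past) the conditional one.
    return [n + cfg_scale * (p - n) for p, n in zip(posi, nega)]


def euler_step(latents, prediction, sigma, sigma_next):
    # One Euler step; sigma_next < sigma, so the latents move towards the data.
    return [x + (sigma_next - sigma) * v for x, v in zip(latents, prediction)]


def denoise(latents, model, sigmas, cfg_scale):
    """Run the loop over a decreasing noise schedule that ends at 0.0."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        posi = model(latents, sigma, conditional=True)
        if cfg_scale == 1.0:
            prediction = posi  # guidance with scale 1 equals the conditional prediction
        else:
            nega = model(latents, sigma, conditional=False)
            prediction = cfg_combine(posi, nega, cfg_scale)
        latents = euler_step(latents, prediction, sigma, sigma_next)
    return latents


TARGET = [2.0, 3.0]


def toy_model(latents, sigma, conditional):
    # Hypothetical oracle: if latents = (1 - sigma) * clean + sigma * noise,
    # then (latents - clean) / sigma equals the velocity noise - clean.
    # The conditional branch "knows" TARGET; the unconditional one assumes zeros.
    clean = TARGET if conditional else [0.0] * len(latents)
    return [(x - c) / sigma for x, c in zip(latents, clean)]


sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]
print(denoise([0.5, -1.0], toy_model, sigmas, cfg_scale=1.0))  # [2.0, 3.0]
```

The oracle's velocity stays constant along the straight noise-to-data path, so the Euler steps land on `TARGET` when `cfg_scale=1.0`. With `cfg_scale=2.0`, guidance pushes the result away from the unconditional prediction, and the loop ends at twice `TARGET`.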
## 3. Preparing Dataset

To quickly verify that training works, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1). It is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh) and contains the 151 first-generation Pokemon, from Bulbasaur to Mew. To use other datasets, see [Preparing Datasets](../Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](../API_Reference/core/data.md).

```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```

## 4. Start Training

The Pipeline makes the training process quick to implement. The complete code is at [../Research_Tutorial/train_from_scratch.py](https://github.com/modelscope/DiffSynth-Studio/blob/main/docs/en/Research_Tutorial/train_from_scratch.py). Run `python docs/en/Research_Tutorial/train_from_scratch.py` to start single-GPU training. For multi-GPU parallel training, first run `accelerate config` to set the relevant parameters, then start training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`.

The training script has no stopping condition, so stop it manually when needed. The model converges after approximately 60,000 training steps, which takes 10 to 20 hours on a single GPU.
Training Code

```python
class AAATrainingModule(DiffusionTrainingModule):
    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        # The DiT starts from random initialization and is the only trainable model.
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            # PIL's Image.size is (width, height).
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
```
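The loss returned by `FlowMatchSFTLoss` is based on flow matching. The sketch below shows the standard rectified-flow form of one training step on plain Python lists: mix the clean latent with Gaussian noise at a random `sigma`, then regress the model output onto the velocity `noise - clean`. This is an assumption about the objective, not the library's code. In particular, sampling `sigma` uniformly is a simplification of the scheduler's 1000-step training schedule, and any loss weighting is omitted.

```python
# Sketch of one flow-matching training step (assumed rectified-flow objective).
import random


def add_noise(clean, noise, sigma):
    # Straight-line interpolation: sigma=0 gives the clean latent, sigma=1 pure noise.
    return [(1 - sigma) * x + sigma * e for x, e in zip(clean, noise)]


def training_target(clean, noise):
    # Velocity along that line; this is what the DiT learns to predict.
    return [e - x for x, e in zip(clean, noise)]


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def training_step(model, clean, rng):
    noise = [rng.gauss(0.0, 1.0) for _ in clean]
    sigma = rng.random()  # simplification: uniform point on the noise schedule
    noisy = add_noise(clean, noise, sigma)
    pred = model(noisy, sigma)
    return mse(pred, training_target(clean, noise))


def untrained(noisy, sigma):
    # Stand-in for a freshly initialized DiT: always predicts zeros.
    return [0.0] * len(noisy)


print(training_step(untrained, [0.2, -0.4, 0.9], random.Random(0)))
```

A model that always predicts zeros gets a positive loss, while one that recovers `noise - clean` exactly drives the loss to zero. Training the DiT amounts to closing that gap across many images and noise levels.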
## 5. Verifying Training Results

If you don't want to wait for training to finish, you can download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel) directly.

```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```

Loading the model:

```python
# Assumes the model and pipeline definitions above are in scope.
from diffsynth import load_model

pipe = AAAImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```

Run inference to generate the first-generation Pokemon "starter trio". At this point, the generated images largely match the training data.

```python
for seed, prompt in enumerate([
    "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
    "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
    "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed,
        height=256,
        width=256,
    )
    image.save(f"image_{seed}.jpg")
```

|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)|
|-|-|-|

Next, run inference to generate Pokemon from the prompt "sharp claws".
Different random seeds produce different images.

```python
for seed, prompt in enumerate([
    "sharp claws",
    "sharp claws",
    "sharp claws",
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed+4,
        height=256,
        width=256,
    )
    image.save(f"image_sharp_claws_{seed}.jpg")
```

|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)|
|-|-|-|

We now have a small 0.1B text-to-image model. It can generate the 151 Pokemon it was trained on, but no other image content. Building on this, scaling up the data, the model parameters, and the number of GPUs lets you train a more powerful text-to-image model!