From 98ab238340449da1e37e82a09561fe5562369a0d Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 2 Feb 2026 14:28:26 +0800 Subject: [PATCH] add research tutorial sec 1 --- README.md | 2 + README_zh.md | 2 + diffsynth/configs/model_configs.py | 8 + diffsynth/models/z_image_text_encoder.py | 30 ++ docs/en/README.md | 2 +- .../Research_Tutorial/train_from_scratch.md | 476 +++++++++++++++++ .../Research_Tutorial/train_from_scratch.py | 341 +++++++++++++ docs/zh/README.md | 2 +- .../Research_Tutorial/train_from_scratch.md | 477 ++++++++++++++++++ .../Research_Tutorial/train_from_scratch.py | 341 +++++++++++++ 10 files changed, 1679 insertions(+), 2 deletions(-) create mode 100644 docs/en/Research_Tutorial/train_from_scratch.md create mode 100644 docs/en/Research_Tutorial/train_from_scratch.py create mode 100644 docs/zh/Research_Tutorial/train_from_scratch.md create mode 100644 docs/zh/Research_Tutorial/train_from_scratch.py diff --git a/README.md b/README.md index 8d2262f..5127c1b 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ We believe that a well-developed open-source code framework can lower the thresh > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. +- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models. + - **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md). - **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available. diff --git a/README_zh.md b/README_zh.md index ee30d85..b3e916a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -33,6 +33,8 @@ DiffSynth 目前包括两个开源项目: > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 +- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。 + - **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。 - **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。 diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index c93f5e9..cc434ed 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -589,6 +589,14 @@ z_image_series = [ "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", "extra_kwargs": {"compress_dim": 128}, }, + { + # Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors") + "model_hash": "1392adecee344136041e70553f875f31", + "model_name": "z_image_text_encoder", + "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder", + "extra_kwargs": {"model_size": "0.6B"}, + "state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter", + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/models/z_image_text_encoder.py b/diffsynth/models/z_image_text_encoder.py index 4d6271d..6f3e6c0 100644 --- a/diffsynth/models/z_image_text_encoder.py +++ b/diffsynth/models/z_image_text_encoder.py @@ -6,6 +6,36 @@ class ZImageTextEncoder(torch.nn.Module): def __init__(self, model_size="4B"): super().__init__() config_dict = { + "0.6B": Qwen3Config(**{ + "architectures": [ + "Qwen3ForCausalLM" + ], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 40960, + "max_window_layers": 28, + "model_type": "qwen3", + "num_attention_heads": 16, + "num_hidden_layers": 28, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 1000000, + "sliding_window": None, + "tie_word_embeddings": True, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936 + }), "4B": Qwen3Config(**{ "architectures": [ "Qwen3ForCausalLM" diff --git a/docs/en/README.md b/docs/en/README.md index 39ae439..e968637 100644 --- a/docs/en/README.md +++ b/docs/en/README.md @@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies. -* Training models from scratch 【coming soon】 +* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md) * Inference improvement techniques 【coming soon】 * Designing controllable generation models 【coming soon】 * Creating new training paradigms 【coming soon】 diff --git a/docs/en/Research_Tutorial/train_from_scratch.md b/docs/en/Research_Tutorial/train_from_scratch.md new file mode 100644 index 0000000..6d5ff76 --- /dev/null +++ b/docs/en/Research_Tutorial/train_from_scratch.md @@ -0,0 +1,476 @@ +# Training Models from Scratch + +DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch. + +## 1. Building Model Architecture + +### 1.1 Diffusion Model + +From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include: + +* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise +* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder +* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at + +The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`. + +
+Model Architecture Code + +```python +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb +``` + +
+ +### 1.2 Encoder-Decoder Models + +Besides the Diffusion model used for denoising, we also need two other models: + +* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model. +* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B). + +The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code. + +## 2. Building Pipeline + +We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder. + +
+Pipeline Code + +```python +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output +``` + +
+ +## 3. Preparing Dataset + +To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md). + +```shell +modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data +``` + +### 4. Start Training + +The training process can be quickly implemented using Pipeline. We have placed the complete code at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training. + +To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training. + +This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training. + +
+Training Code + +```python +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) +``` + +
+ +## 5. Verifying Training Results + +If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). + +```shell +modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel +``` + +Loading the model + +```python +from diffsynth import load_model + +pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), +) +pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda") +``` + +Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data. + +```python +for seed, prompt in enumerate([ + "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws", + "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws", + "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed, + height=256, width=256, + ) + image.save(f"image_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)| +|-|-|-| + +Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results. + +```python +for seed, prompt in enumerate([ + "sharp claws", + "sharp claws", + "sharp claws", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed+4, + height=256, width=256, + ) + image.save(f"image_sharp_claws_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)| +|-|-|-| + +Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model! \ No newline at end of file diff --git a/docs/en/Research_Tutorial/train_from_scratch.py b/docs/en/Research_Tutorial/train_from_scratch.py new file mode 100644 index 0000000..622e091 --- /dev/null +++ b/docs/en/Research_Tutorial/train_from_scratch.py @@ -0,0 +1,341 @@ +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb + + +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output + + +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) \ No newline at end of file diff --git a/docs/zh/README.md b/docs/zh/README.md index edcef50..c02665f 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -77,7 +77,7 @@ graph LR; 本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。 -* 从零开始训练模型【coming soon】 +* [从零开始训练模型](/docs/zh/Research_Tutorial/train_from_scratch.md) * 推理改进优化技术【coming soon】 * 设计可控生成模型【coming soon】 * 创建新的训练范式【coming soon】 diff --git a/docs/zh/Research_Tutorial/train_from_scratch.md b/docs/zh/Research_Tutorial/train_from_scratch.md new file mode 100644 index 0000000..3d5c6d0 --- /dev/null +++ b/docs/zh/Research_Tutorial/train_from_scratch.md @@ -0,0 +1,477 @@ +# 从零开始训练模型 + +DiffSynth-Studio 的训练引擎支持从零开始训练基础模型,本文介绍如何从零开始训练一个参数量仅为 0.1B 的小型文生图模型。 + +## 1. 构建模型结构 + +### 1.1 Diffusion 模型 + +从 UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) 到 DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206),Diffusion 的主流模型结构经历了多次演变。通常,一个 Diffusion 模型的输入包括: + +* 图像张量(`latents`):图像的编码,由 VAE 模型产生,含有部分噪声 +* 文本张量(`prompt_embeds`):文本的编码,由文本编码器产生 +* 时间步(`timestep`):标量,用于标记当前处于 Diffusion 过程的哪个阶段 + +模型的输出是与图像张量形状相同的张量,表示模型预测的去噪方向,关于 Diffusion 模型理论的细节,请参考 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)。在本文中,我们构建一个仅含 0.1B 参数的 DiT 模型:`AAADiT`。 + +
+模型结构代码 + +```python +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb +``` + +
+ +### 1.2 编解码器模型 + +除了用于去噪的 Diffusion 模型以外,我们还需要另外两个模型: + +* 文本编码器:用于将文本编码为张量。我们采用 [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) 模型。 +* VAE 编解码器:编码器部分用于将图像编码为张量,解码器部分用于将图像张量解码为图像。我们采用 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 中的 VAE 模型。 + +这两个模型的结构都已集成在 DiffSynth-Studio 中,分别位于 [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) 和 [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py),因此我们不需要修改任何代码。 + +## 2. 构建 Pipeline + +我们在文档 [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 中介绍了如何构建一个模型 Pipeline,对于本文中的模型,我们也需要构建一个 Pipeline,连接文本编码器、Diffusion 模型、VAE 编解码器。 + +
+Pipeline 代码 + +```python +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output +``` + +
+ +## 3. 准备数据集 + +为了快速验证训练效果,我们使用数据集 [宝可梦-第一世代](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1),这个数据集转载自开源项目 [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh),包含从妙蛙种子到梦幻的 151 个第一世代宝可梦。如果你想使用其他数据集,请参考文档 [准备数据集](/docs/zh/Pipeline_Usage/Model_Training.md#准备数据集) 和 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。 + +```shell +modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data +``` + +### 4. 开始训练 + +训练过程可使用 Pipeline 快速实现,我们已将完整的代码放在 [/docs/zh/Research_Tutorial/train_from_scratch.py](/docs/zh/Research_Tutorial/train_from_scratch.py),可直接通过 `python docs/zh/Research_Tutorial/train_from_scratch.py` 开始单 GPU 训练。 + +如需开启多 GPU 并行训练,请运行 `accelerate config` 设置相关参数,然后使用命令 `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py` 开始训练。 + +这个训练脚本没有设置停止条件,请在需要时手动关闭。模型在训练大约 6 万步后收敛,单 GPU 训练需要 10~20 小时。 + + +
+训练代码 + +```python +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) +``` + +
+ +## 5. 验证训练效果 + +如果你不想等待模型训练完成,可以直接下载[我们预先训练好的模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)。 + +```shell +modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel +``` + +加载模型 + +```python +from diffsynth import load_model + +pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), +) +pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda") +``` + +模型推理,生成第一世代宝可梦“御三家”,此时模型生成的图像内容与训练数据基本一致。 + +```python +for seed, prompt in enumerate([ + "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws", + "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws", + "蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed, + height=256, width=256, + ) + image.save(f"image_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/3c620fbf-5d28-4a1a-b887-519d85ac7d1c)|![Image](https://github.com/user-attachments/assets/909efd4c-9e61-4b33-9321-39da0e499b00)|![Image](https://github.com/user-attachments/assets/f3474bcd-b474-4a90-a1ea-579f67e161e3)| +|-|-|-| + +模型推理,生成具有“锐利爪子”的宝可梦,此时不同的随机种子能够产生不同的图像结果。 + +```python +for seed, prompt in enumerate([ + "sharp claws", + "sharp claws", + "sharp claws", +]): + image = pipe( + prompt=prompt, + negative_prompt=" ", + num_inference_steps=30, + cfg_scale=10, + seed=seed+4, + height=256, width=256, + ) + image.save(f"image_sharp_claws_{seed}.jpg") +``` + +|![Image](https://github.com/user-attachments/assets/94862edd-96ae-4276-a38f-795249f11a13)|![Image](https://github.com/user-attachments/assets/b2291f23-20ba-42de-8bfd-76cb4afc6eea)|![Image](https://github.com/user-attachments/assets/f2aab9a4-85ec-498e-8039-648b1289796e)| +|-|-|-| + +现在,我们获得了一个 0.1B 的小型文生图模型,这个模型已经能够生成 151 个宝可梦,但无法生成其他图像内容。如果在此基础上增加数据量、模型参数量、GPU 数量,你就可以训练出一个更强大的文生图模型! diff --git a/docs/zh/Research_Tutorial/train_from_scratch.py b/docs/zh/Research_Tutorial/train_from_scratch.py new file mode 100644 index 0000000..622e091 --- /dev/null +++ b/docs/zh/Research_Tutorial/train_from_scratch.py @@ -0,0 +1,341 @@ +import torch, accelerate +from PIL import Image +from typing import Union +from tqdm import tqdm +from einops import rearrange, repeat + +from transformers import AutoProcessor, AutoTokenizer +from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model +from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task +from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit +from diffsynth.models.general_modules import TimestepEmbeddings +from diffsynth.models.z_image_text_encoder import ZImageTextEncoder +from diffsynth.models.flux2_vae import Flux2VAE + + +class AAAPositionalEmbedding(torch.nn.Module): + def __init__(self, height=16, width=16, dim=1024): + super().__init__() + self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width))) + self.text_emb = torch.nn.Parameter(torch.randn((dim,))) + + def forward(self, image, text): + height, width = image.shape[-2:] + image_emb = self.image_emb.to(device=image.device, dtype=image.dtype) + image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear") + image_emb = rearrange(image_emb, "B C H W -> B (H W) C") + text_emb = self.text_emb.to(device=text.device, dtype=text.dtype) + text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1]) + emb = torch.concat([image_emb, text_emb], dim=1) + return emb + + +class AAABlock(torch.nn.Module): + def __init__(self, dim=1024, num_heads=32): + super().__init__() + self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.to_q = torch.nn.Linear(dim, dim) + self.to_k = torch.nn.Linear(dim, dim) + self.to_v = torch.nn.Linear(dim, dim) + self.to_out = torch.nn.Linear(dim, dim) + self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*3), + torch.nn.SiLU(), + torch.nn.Linear(dim*3, dim), + ) + self.to_gate = torch.nn.Linear(dim, dim * 2) + self.num_heads = num_heads + + def attention(self, emb, pos_emb): + emb = self.norm_attn(emb + pos_emb) + q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb) + emb = attention_forward( + q, k, v, + q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)", + dims={"n": self.num_heads}, + ) + emb = self.to_out(emb) + return emb + + def feed_forward(self, emb, pos_emb): + emb = self.norm_mlp(emb + pos_emb) + emb = self.ff(emb) + return emb + + def forward(self, emb, pos_emb, t_emb): + gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1) + emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn) + emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp) + return emb + + +class AAADiT(torch.nn.Module): + def __init__(self, dim=1024): + super().__init__() + self.pos_embedder = AAAPositionalEmbedding(dim=dim) + self.timestep_embedder = TimestepEmbeddings(256, dim) + self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim)) + self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim)) + self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)]) + self.proj_out = torch.nn.Linear(dim, 128) + + def forward( + self, + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + ): + pos_emb = self.pos_embedder(latents, prompt_embeds) + t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1) + image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C")) + text = self.text_embedder(prompt_embeds) + emb = torch.concat([image, text], dim=1) + for block_id, block in enumerate(self.blocks): + emb = gradient_checkpoint_forward( + block, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + emb=emb, + pos_emb=pos_emb, + t_emb=t_emb, + ) + emb = emb[:, :latents.shape[-1] * latents.shape[-2]] + emb = self.proj_out(emb) + emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1]) + return emb + + +class AAAImagePipeline(BasePipeline): + def __init__(self, device="cuda", torch_dtype=torch.bfloat16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=16, width_division_factor=16, + ) + self.scheduler = FlowMatchScheduler("FLUX.2") + self.text_encoder: ZImageTextEncoder = None + self.dit: AAADiT = None + self.vae: Flux2VAE = None + self.tokenizer: AutoProcessor = None + self.in_iteration_models = ("dit",) + self.units = [ + AAAUnit_PromptEmbedder(), + AAAUnit_NoiseInitializer(), + AAAUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_aaa + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.bfloat16, + device: Union[str, torch.device] = "cuda", + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + # Initialize pipeline + pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype) + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + + # Fetch models + pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder") + pipe.dit = model_pool.fetch_model("aaa_dit") + pipe.vae = model_pool.fetch_model("flux2_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + + # VRAM Management + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + # Prompt + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 1.0, + # Image + input_image: Image.Image = None, + denoising_strength: float = 1.0, + # Shape + height: int = 1024, + width: int = 1024, + # Randomness + seed: int = None, + rand_device: str = "cpu", + # Steps + num_inference_steps: int = 30, + # Progress bar + progress_bar_cmd = tqdm, + ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + + # Parameters + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "input_image": input_image, "denoising_strength": denoising_strength, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "num_inference_steps": num_inference_steps, + } + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + + # Denoise + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared) + + # Decode + self.load_models_to_device(['vae']) + image = self.vae.decode(inputs_shared["latents"]) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class AAAUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + self.hidden_states_layers = (-1,) + + def process(self, pipe: AAAImagePipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + text = pipe.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device) + output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False) + prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1) + return {"prompt_embeds": prompt_embeds} + + +class AAAUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device): + noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype) + return {"noise": noise} + + +class AAAUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("input_image", "noise"), + output_params=("latents", "input_latents"), + onload_model_names=("vae",) + ) + + def process(self, pipe: AAAImagePipeline, input_image, noise): + if input_image is None: + return {"latents": noise, "input_latents": None} + pipe.load_models_to_device(['vae']) + image = pipe.preprocess_image(input_image) + input_latents = pipe.vae.encode(image) + if pipe.scheduler.training: + return {"latents": noise, "input_latents": input_latents} + else: + latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0]) + return {"latents": latents, "input_latents": input_latents} + + +def model_fn_aaa( + dit: AAADiT, + latents=None, + prompt_embeds=None, + timestep=None, + use_gradient_checkpointing=False, + use_gradient_checkpointing_offload=False, + **kwargs, +): + model_output = dit( + latents, + prompt_embeds, + timestep, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + return model_output + + +class AAATrainingModule(DiffusionTrainingModule): + def __init__(self, device): + super().__init__() + self.pipe = AAAImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""), + ) + self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device) + self.pipe.freeze_except(["dit"]) + self.pipe.scheduler.set_timesteps(1000, training=True) + + def forward(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "cfg_scale": 1, + "use_gradient_checkpointing": False, + "use_gradient_checkpointing_offload": False, + } + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi) + return loss + + +if __name__ == "__main__": + accelerator = accelerate.Accelerator(gradient_accumulation_steps=1) + dataset = UnifiedDataset( + base_path="data/images", + metadata_path="data/metadata_merged.csv", + max_data_items=10000000, + data_file_keys=("image",), + main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256) + ) + model = AAATrainingModule(device=accelerator.device) + model_logger = ModelLogger( + "models/AAA/v1", + remove_prefix_in_ckpt="pipe.dit.", + ) + launch_training_task( + accelerator, dataset, model, model_logger, + learning_rate=2e-4, + num_workers=4, + save_steps=50000, + num_epochs=999999, + ) \ No newline at end of file