mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add research tutorial sec 1
This commit is contained in:
476
docs/en/Research_Tutorial/train_from_scratch.md
Normal file
476
docs/en/Research_Tutorial/train_from_scratch.md
Normal file
@@ -0,0 +1,476 @@
|
||||
# Training Models from Scratch
|
||||
|
||||
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
|
||||
|
||||
## 1. Building Model Architecture
|
||||
|
||||
### 1.1 Diffusion Model
|
||||
|
||||
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include:
|
||||
|
||||
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
|
||||
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
|
||||
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
|
||||
|
||||
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
|
||||
|
||||
<details>
|
||||
<summary>Model Architecture Code</summary>
|
||||
|
||||
```python
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 1.2 Encoder-Decoder Models
|
||||
|
||||
Besides the Diffusion model used for denoising, we also need two other models:
|
||||
|
||||
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
|
||||
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
|
||||
|
||||
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
|
||||
|
||||
## 2. Building Pipeline
|
||||
|
||||
We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
|
||||
|
||||
<details>
|
||||
<summary>Pipeline Code</summary>
|
||||
|
||||
```python
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 3. Preparing Dataset
|
||||
|
||||
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
|
||||
```
|
||||
|
||||
### 4. Start Training
|
||||
|
||||
The training process can be quickly implemented using Pipeline. We have placed the complete code at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.
|
||||
|
||||
To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training.
|
||||
|
||||
This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training.
|
||||
|
||||
<details>
|
||||
<summary>Training Code</summary>
|
||||
|
||||
```python
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 5. Verifying Training Results
|
||||
|
||||
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
|
||||
|
||||
```shell
|
||||
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
|
||||
```
|
||||
|
||||
Loading the model
|
||||
|
||||
```python
|
||||
from diffsynth import load_model
|
||||
|
||||
pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""),
|
||||
)
|
||||
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
|
||||
```
|
||||
|
||||
Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data.
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
|
||||
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
|
||||
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results.
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed+4,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_sharp_claws_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model!
|
||||
341
docs/en/Research_Tutorial/train_from_scratch.py
Normal file
341
docs/en/Research_Tutorial/train_from_scratch.py
Normal file
@@ -0,0 +1,341 @@
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
|
||||
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
|
||||
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
Reference in New Issue
Block a user