Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 16:50:47 +00:00)

Merge branch 'main' into ltx-2
@@ -85,6 +85,7 @@ graph LR;
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
@@ -50,9 +50,14 @@ image.save("image.jpg")

## Model Overview

|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|

Special Training Scripts:

@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:

* `seed`: Random seed. Default is `None`, meaning completely random.
* `rand_device`: Device on which the random Gaussian noise is generated; defaults to `"cpu"`. When set to `"cuda"`, different GPUs will produce different generation results.
* `num_inference_steps`: Number of inference steps; defaults to 8.
* `controlnet_inputs`: Inputs for ControlNet models.
* `edit_image`: Input image(s) for image editing models; multiple images are supported.
* `positive_only_lora`: LoRA weights applied only to positive prompts.

If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.
@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`

This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.

* Training models from scratch 【coming soon】
* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
* Inference improvement techniques 【coming soon】
* Designing controllable generation models 【coming soon】
* Creating new training paradigms 【coming soon】
476 docs/en/Research_Tutorial/train_from_scratch.md Normal file
@@ -0,0 +1,476 @@
# Training Models from Scratch

DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model, with only 0.1B parameters, from scratch.

## 1. Building Model Architecture

### 1.1 Diffusion Model

From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), mainstream diffusion model architectures have evolved several times. Typically, a diffusion model's inputs include:

* Image tensor (`latents`): the encoding of the image, generated by the VAE, with noise partially added
* Text tensor (`prompt_embeds`): the encoding of the text, generated by the text encoder
* Timestep (`timestep`): a scalar marking the current stage of the diffusion process

The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details on diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
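The training section later uses a flow-matching loss, in which the noisy input interpolates linearly between the clean latent and Gaussian noise, and the model learns the constant velocity pointing from data to noise. As a schematic, scalar-level sketch (sign conventions vary between implementations, so this illustrates the idea rather than DiffSynth-Studio's exact loss):

```python
# Schematic flow-matching illustration with scalars.
# x0: clean latent value, eps: noise sample, t in [0, 1].
def interpolate(x0, eps, t):
    # Noisy sample on the straight path from data (t=0) to noise (t=1)
    return (1 - t) * x0 + t * eps

def velocity_target(x0, eps):
    # The constant velocity the model is trained to predict
    return eps - x0

x0, eps = 2.0, -1.0
x_t = interpolate(x0, eps, t=0.25)
v = velocity_target(x0, eps)
# Integrating the predicted velocity from t=0.25 back to t=0 recovers x0
x_recovered = x_t - 0.25 * v
print(x_t, v, x_recovered)  # 1.25 -3.0 2.0
```

At inference time, the sampler performs exactly this kind of integration step by step, starting from pure noise at `t=1`.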

<details>

<summary>Model Architecture Code</summary>
```python
import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat

from transformers import AutoProcessor, AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb


class AAABlock(torch.nn.Module):
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim*3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb

    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb

    def forward(self, emb, pos_emb, t_emb):
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block_id, block in enumerate(self.blocks):
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb
```

</details>
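As a sanity check on the "0.1B" figure, we can estimate the parameter count of `AAADiT` from the layer shapes above. This is back-of-the-envelope arithmetic over the transformer blocks only (the embedders and output projection add a few million more); on a real instance you would use `sum(p.numel() for p in model.parameters())` instead.

```python
# Rough parameter count for AAADiT's blocks (dim=1024, 10 blocks),
# derived from the layer shapes in the architecture code above.
dim, num_blocks = 1024, 10

def linear_params(d_in, d_out):
    return d_in * d_out + d_out  # weight + bias

per_block = (
    4 * linear_params(dim, dim)      # to_q, to_k, to_v, to_out
    + linear_params(dim, dim * 3)    # ff up-projection
    + linear_params(dim * 3, dim)    # ff down-projection
    + linear_params(dim, dim * 2)    # to_gate
)
total_blocks = per_block * num_blocks
print(f"{total_blocks / 1e9:.3f}B parameters in the blocks")  # 0.126B
```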

### 1.2 Encoder-Decoder Models

Besides the diffusion model used for denoising, we also need two other models:

* Text encoder: encodes text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
* VAE encoder-decoder: the encoder encodes images into tensors, and the decoder decodes image tensors back into images. We adopt the VAE from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).

The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.

## 2. Building Pipeline

We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we likewise need to build a Pipeline connecting the text encoder, diffusion model, and VAE encoder-decoder.
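Inside the denoising loop of the pipeline, positive- and negative-prompt predictions are combined under a guidance scale. We assume this follows the standard classifier-free guidance rule, sketched here with scalars (the actual DiffSynth-Studio helper may differ in details):

```python
def cfg_combine(pred_posi, pred_nega, cfg_scale):
    # Standard classifier-free guidance: extrapolate from the negative
    # prediction toward the positive one.
    return pred_nega + cfg_scale * (pred_posi - pred_nega)

# cfg_scale=1 reduces to the positive prediction alone, which is why
# pipelines with cfg_scale=1.0 can skip the negative-prompt pass entirely.
print(cfg_combine(0.8, 0.2, 1.0))   # 0.8
print(cfg_combine(0.8, 0.2, 10.0))  # 6.2
```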

<details>

<summary>Pipeline Code</summary>
```python
class AAAImagePipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoProcessor = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output
```

</details>

## 3. Preparing Dataset

To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh) and containing the 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).

```shell
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
```
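The training script reads `data/metadata_merged.csv`, and the training module accesses `data["image"]` and `data["prompt"]` per sample, so we assume the metadata has `image` and `prompt` columns (the exact schema comes from the dataset itself). If you build your own dataset, a metadata file of that shape can be written with the standard library:

```python
import csv

# Hypothetical rows: relative image path plus caption. The column names are
# an assumption matching how the training code accesses each sample.
rows = [
    {"image": "001.png", "prompt": "green, lizard, plant, seed on back"},
    {"image": "004.png", "prompt": "orange, lizard, flame on tail tip"},
]
with open("metadata_example.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```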

## 4. Start Training

The training process can be implemented quickly on top of the Pipeline. The complete code is at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py); single-GPU training can be started directly with `python docs/en/Research_Tutorial/train_from_scratch.py`.

To enable multi-GPU parallel training, run `accelerate config` to set the relevant parameters, then start training with `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py`.

This training script has no stopping condition; stop it manually when needed. The model converges after approximately 60,000 training steps, which takes 10-20 hours on a single GPU.
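To put "60,000 steps" in perspective: with 151 training images and an assumed effective batch size of 1 (nothing in this tutorial configures a larger batch), that corresponds to roughly 400 passes over the dataset.

```python
# Back-of-the-envelope: how many epochs do ~60,000 steps represent?
num_images = 151    # first-generation Pokemon in the dataset
steps = 60_000      # approximate steps to convergence
batch_size = 1      # assumed effective batch size
epochs = steps * batch_size / num_images
print(f"~{epochs:.0f} epochs")  # ~397
```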

<details>

<summary>Training Code</summary>
```python
class AAATrainingModule(DiffusionTrainingModule):
    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
```

</details>

## 5. Verifying Training Results

If you don't want to wait for training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).

```shell
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
```

Load the model:
```python
from diffsynth import load_model

pipe = AAAImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
)
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
```

Run inference to generate the first-generation "starter trio". At this point, the images generated by the model closely match the training data.
```python
for seed, prompt in enumerate([
    "green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
    "orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
    "blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed,
        height=256, width=256,
    )
    image.save(f"image_{seed}.jpg")
```

Run inference to generate Pokemon with "sharp claws". At this point, different random seeds produce different images.
```python
for seed, prompt in enumerate([
    "sharp claws",
    "sharp claws",
    "sharp claws",
]):
    image = pipe(
        prompt=prompt,
        negative_prompt=" ",
        num_inference_steps=30,
        cfg_scale=10,
        seed=seed+4,
        height=256, width=256,
    )
    image.save(f"image_sharp_claws_{seed}.jpg")
```

Now we have a small 0.1B text-to-image model. It can already generate the 151 Pokemon, but it cannot generate other image content. If you scale up the data, model parameters, and number of GPUs from here, you can train a much more powerful text-to-image model!
341 docs/en/Research_Tutorial/train_from_scratch.py Normal file
@@ -0,0 +1,341 @@
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb
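Because the learned 16×16 map is interpolated at forward time, the same parameters serve any latent resolution. A pure-Python sketch of the bilinear resizing involved (the `bilinear_resize` helper is hypothetical, and uses align_corners-style corner mapping for simplicity, which differs slightly from the default behavior of `torch.nn.functional.interpolate`):

```python
# Bilinear resize of an H x W grid of scalars, illustrating how a fixed learned
# positional map can be stretched to any latent resolution.
def bilinear_resize(grid, out_h, out_w):
    in_h, in_w = len(grid), len(grid[0])
    out = [[0.0] * out_w for _ in range(out_h)]
    for i in range(out_h):
        # Map output coordinates back into the input grid (corners align exactly).
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y); y1 = min(y0 + 1, in_h - 1); wy = y - y0
        for j in range(out_w):
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x); x1 = min(x0 + 1, in_w - 1); wx = x - x0
            top = grid[y0][x0] * (1 - wx) + grid[y0][x1] * wx
            bot = grid[y1][x0] * (1 - wx) + grid[y1][x1] * wx
            out[i][j] = top * (1 - wy) + bot * wy
    return out

grid = [[float(r * 4 + c) for c in range(4)] for r in range(4)]  # a tiny 4x4 "positional map"
resized = bilinear_resize(grid, 8, 8)  # stretched to an 8x8 latent grid
```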


class AAABlock(torch.nn.Module):
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim*3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb

    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb

    def forward(self, emb, pos_emb, t_emb):
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb
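The `(1 + gate)` factors above let the timestep embedding rescale each residual branch. A toy scalar sketch of the modulation pattern (the `gated_residual` helper is illustrative, not part of DiffSynth):

```python
def gated_residual(x, branch, gate):
    # Residual update with a timestep-dependent gate, as in AAABlock.forward.
    return x + branch(x) * (1 + gate)

halve = lambda v: 0.5 * v  # stand-in for the attention or MLP branch

print(gated_residual(2.0, halve, 0.0))   # 3.0: gate 0 leaves the branch at unit scale
print(gated_residual(2.0, halve, -1.0))  # 2.0: gate -1 switches the branch off entirely
```

The gate therefore gives the network a cheap way to vary how strongly each sub-layer acts at different noise levels.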


class AAADiT(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block_id, block in enumerate(self.blocks):
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb


class AAAImagePipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoProcessor = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd = tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image


class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",)
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(['vae'])
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output


class AAATrainingModule(DiffusionTrainingModule):
    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
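`UnifiedDataset` in the script above pairs each training image with a caption through a metadata CSV. The column schema used here is an assumption (an `image` path column plus a `prompt` column), but producing such a file is straightforward:

```python
import csv

# Hypothetical two-sample metadata file; the column names are assumed,
# not taken from the DiffSynth documentation.
rows = [
    {"image": "001.jpg", "prompt": "a drawing of Bulbasaur"},
    {"image": "002.jpg", "prompt": "a drawing of Ivysaur"},
]
with open("metadata_demo.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```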
@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un

Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.

-(Figure)
+

This process is intuitive, but to understand the details, we need to answer several questions:

@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma

At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.

-(Figure)
+

## How is the iterative denoising computation performed?

@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha

Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.

-(Figure)
-
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
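Since the network approximates $x_T-x_0$, a perfect prediction lets one denoising step slide $x_t$ along the straight interpolation path to a lower noise level. A toy check (the explicit update form used here is an assumption consistent with the interpolation formula above; the document's own update equation lies outside this excerpt):

```python
x0, xT = 2.0, -1.0             # clean sample and noise (toy scalars)
sigma_t, sigma_prev = 0.8, 0.6  # current and next noise levels

x_t = (1 - sigma_t) * x0 + sigma_t * xT
v = xT - x0                     # what a perfect model would output
x_prev = x_t - (sigma_t - sigma_prev) * v

# x_prev lands exactly on the interpolation path at the lower noise level:
on_path = (1 - sigma_prev) * x0 + sigma_prev * xT
print(abs(x_prev - on_path) < 1e-9)  # True
```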

Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:

@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra

The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
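A single training iteration from the paragraph above, reduced to toy vectors: draw one random step $t$, synthesize $x_t$, and regress against the direction $x_T-x_0$ (the mean-squared loss matches the flow-matching objective; the zero "prediction" stands in for an untrained network):

```python
import random

random.seed(0)
x0 = [0.5, -1.0, 2.0]                    # a clean data sample (toy vector)
xT = [random.gauss(0, 1) for _ in x0]    # Gaussian noise of the same shape
sigma_t = random.random()                # noise level of a randomly drawn step t

x_t = [(1 - sigma_t) * a + sigma_t * b for a, b in zip(x0, xT)]  # model input
target = [b - a for a, b in zip(x0, xT)]  # regression target: x_T - x_0
pred = [0.0] * len(x0)                    # stand-in for the network's output
loss = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(x0)
```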

-(Figure)
-
The following is pseudocode for the training process:

> Obtain data sample $x_0$ and guidance condition $c$ from the dataset

@@ -113,7 +109,7 @@ The following is pseudocode for the training process:

From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.

-(Figure)
+

### Data Encoder-Decoder