# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
import torch
import accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat

from transformers import AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    """Learned positional embedding: a 2D grid for image tokens plus a single shared vector for text tokens."""

    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        height, width = image.shape[-2:]
        # Resize the learned grid to the current latent resolution.
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        # Broadcast the single text embedding over batch and sequence length.
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb
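

# Note (added remark, not from the original file): because the learned 16x16
# grid is bilinearly resized in forward(), the same parameters cover any latent
# resolution, e.g. 16x16 tokens for a 256x256 image or 64x64 tokens for a
# 1024x1024 image, given the 16x VAE downsampling used by the pipeline below.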


class AAABlock(torch.nn.Module):
    """Transformer block whose attention and MLP branches are scaled by timestep-conditioned gates."""

    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim * 3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb

    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb

    def forward(self, emb, pos_emb, t_emb):
        # Timestep-conditioned gates modulate both residual branches (AdaLN-style "1 + gate" scaling).
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    """A minimal DiT: image and text tokens share one sequence across 10 gated blocks; only image tokens are projected back to latent space."""

    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block in self.blocks:
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        # Keep only the image tokens, then project back to latent channels.
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb
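

# A minimal shape sketch (added remark; dummy tensors only): with the defaults
# above, image tokens come from 128-channel latents and text tokens from
# 1024-dim prompt embeddings, so a forward pass looks like
#
#     dit = AAADiT()
#     latents = torch.randn(1, 128, 16, 16)        # 256x256 pixels -> 16x16 latents
#     prompt_embeds = torch.randn(1, 128, 1024)    # 128 text tokens
#     timestep = torch.tensor([500.0])
#     out = dit(latents, prompt_embeds, timestep)  # -> (1, 128, 16, 16)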


class AAAImagePipeline(BasePipeline):
    """Text-to-image (and image-to-image) pipeline built from the units and model_fn defined below."""

    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoTokenizer = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)

        # VRAM management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)

        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image, "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id,
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)

        # Decode
        self.load_models_to_device(["vae"])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])

        return image
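

# Inference usage, as a sketch (added remark): the model configs mirror the ones
# used in AAATrainingModule below; loading trained DiT weights is left as a
# placeholder, since this file does not register "aaa_dit" with the model loader.
#
#     pipe = AAAImagePipeline.from_pretrained(
#         model_configs=[
#             ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
#             ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
#         ],
#         tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
#     )
#     pipe.dit = AAADiT().to(dtype=pipe.torch_dtype, device=pipe.device)
#     # ... load trained DiT weights from models/AAA/v1 here ...
#     image = pipe(prompt="a photo of a cat", seed=0, num_inference_steps=30)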


class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,  # (sic) upstream kwarg spelling; runs positive and negative prompts separately
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",),
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}
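

# Added remark: with Qwen3-0.6B as the text encoder (hidden size 1024, assumed
# from the configs in AAATrainingModule) and max_length=128 above, prompt_embeds
# has shape (1, 128, 1024), matching the text_embedder input dimension in AAADiT.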


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(["vae"])
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            # During training, the flow-matching loss combines noise and input_latents itself.
            return {"latents": noise, "input_latents": input_latents}
        else:
            # At inference, noise the encoded image for image-to-image generation.
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output


class AAATrainingModule(DiffusionTrainingModule):
    """Trains AAADiT from scratch; the Qwen3-0.6B text encoder and FLUX.2 VAE stay frozen."""

    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
        )
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss


if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10_000_000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256),
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
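

# Added remark: the dataset must yield dicts with the keys consumed in
# AAATrainingModule.forward, i.e. "prompt" (a string) and "image" (an image file
# under data/images). A minimal metadata_merged.csv sketch, with column names
# assumed from those keys:
#
#     image,prompt
#     cat.jpg,a photo of a cat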