import torch, accelerate
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange, repeat
from transformers import AutoTokenizer
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
from diffsynth.models.general_modules import TimestepEmbeddings
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
from diffsynth.models.flux2_vae import Flux2VAE


class AAAPositionalEmbedding(torch.nn.Module):
    def __init__(self, height=16, width=16, dim=1024):
        super().__init__()
        self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
        self.text_emb = torch.nn.Parameter(torch.randn((dim,)))

    def forward(self, image, text):
        # Learned 2D positional map, resized to match the latent resolution.
        height, width = image.shape[-2:]
        image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
        image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
        image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
        # A single shared embedding for every text token.
        text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
        text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
        emb = torch.concat([image_emb, text_emb], dim=1)
        return emb


class AAABlock(torch.nn.Module):
    def __init__(self, dim=1024, num_heads=32):
        super().__init__()
        self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.Linear(dim, dim)
        self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 3),
            torch.nn.SiLU(),
            torch.nn.Linear(dim * 3, dim),
        )
        self.to_gate = torch.nn.Linear(dim, dim * 2)
        self.num_heads = num_heads

    def attention(self, emb, pos_emb):
        emb = self.norm_attn(emb + pos_emb)
        q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
        emb = attention_forward(
            q, k, v,
            q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)",
            out_pattern="b s (n d)",
            dims={"n": self.num_heads},
        )
        emb = self.to_out(emb)
        return emb

    def feed_forward(self, emb, pos_emb):
        emb = self.norm_mlp(emb + pos_emb)
        emb = self.ff(emb)
        return emb

    def forward(self, emb, pos_emb, t_emb):
        # Timestep-conditioned gates for the attention and MLP branches.
        gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
        emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
        emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
        return emb


class AAADiT(torch.nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.pos_embedder = AAAPositionalEmbedding(dim=dim)
        self.timestep_embedder = TimestepEmbeddings(256, dim)
        self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
        self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
        self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
        self.proj_out = torch.nn.Linear(dim, 128)

    def forward(
        self,
        latents,
        prompt_embeds,
        timestep,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        pos_emb = self.pos_embedder(latents, prompt_embeds)
        t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
        # Token sequence layout: [image tokens | text tokens].
        image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
        text = self.text_embedder(prompt_embeds)
        emb = torch.concat([image, text], dim=1)
        for block in self.blocks:
            emb = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                emb=emb,
                pos_emb=pos_emb,
                t_emb=t_emb,
            )
        # Keep only the image tokens, then project back to latent space.
        emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
        emb = self.proj_out(emb)
        emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
        return emb
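# Shape sanity check for AAADiT (a minimal sketch with hypothetical values,
# not part of the pipeline). The 1024-dim text embeddings match the hidden
# size of Qwen3-0.6B, and the 128-channel latents at 1/16 spatial resolution
# match the FLUX.2 VAE configuration assumed by the pipeline below.
#
#   dit = AAADiT()
#   latents = torch.randn(1, 128, 16, 16)        # B C H W, i.e. a 256x256 image
#   prompt_embeds = torch.randn(1, 128, 1024)    # B L C, L = tokenizer max_length
#   timestep = torch.tensor([500.0])
#   assert dit(latents, prompt_embeds, timestep).shape == latents.shape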
class AAAImagePipeline(BasePipeline):
    def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
        super().__init__(
            device=device, torch_dtype=torch_dtype,
            height_division_factor=16, width_division_factor=16,
        )
        self.scheduler = FlowMatchScheduler("FLUX.2")
        self.text_encoder: ZImageTextEncoder = None
        self.dit: AAADiT = None
        self.vae: Flux2VAE = None
        self.tokenizer: AutoTokenizer = None
        self.in_iteration_models = ("dit",)
        self.units = [
            AAAUnit_PromptEmbedder(),
            AAAUnit_NoiseInitializer(),
            AAAUnit_InputImageEmbedder(),
        ]
        self.model_fn = model_fn_aaa

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = "cuda",
        model_configs: list[ModelConfig] = [],
        tokenizer_config: ModelConfig = None,
        vram_limit: float = None,
    ):
        # Initialize pipeline
        pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)
        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
        pipe.dit = model_pool.fetch_model("aaa_dit")
        pipe.vae = model_pool.fetch_model("flux2_vae")
        if tokenizer_config is not None:
            tokenizer_config.download_if_necessary()
            pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: str = "",
        cfg_scale: float = 1.0,
        # Image
        input_image: Image.Image = None,
        denoising_strength: float = 1.0,
        # Shape
        height: int = 1024,
        width: int = 1024,
        # Randomness
        seed: int = None,
        rand_device: str = "cpu",
        # Steps
        num_inference_steps: int = 30,
        # Progress bar
        progress_bar_cmd=tqdm,
    ):
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
        # Parameters
        inputs_posi = {"prompt": prompt}
        inputs_nega = {"negative_prompt": negative_prompt}
        inputs_shared = {
            "cfg_scale": cfg_scale,
            "input_image": input_image,
            "denoising_strength": denoising_strength,
            "height": height, "width": width,
            "seed": seed, "rand_device": rand_device,
            "num_inference_steps": num_inference_steps,
        }
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id,
            )
            inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
        # Decode
        self.load_models_to_device(['vae'])
        image = self.vae.decode(inputs_shared["latents"])
        image = self.vae_output_to_image(image)
        self.load_models_to_device([])
        return image
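# Example inference usage (a sketch; the prompt, output path, and shapes are
# hypothetical, and trained DiT weights must be restored into pipe.dit before
# sampling; the training run below saves them under models/AAA/v1):
#
#   pipe = AAAImagePipeline.from_pretrained(
#       model_configs=[
#           ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
#           ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
#       ],
#       tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""),
#   )
#   pipe.dit = AAADiT().to(dtype=torch.bfloat16, device="cuda")
#   image = pipe(prompt="a cat sitting on a windowsill", seed=0, height=256, width=256)
#   image.save("sample.jpg")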
class AAAUnit_PromptEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("prompt_embeds",),
            onload_model_names=("text_encoder",),
        )
        self.hidden_states_layers = (-1,)

    def process(self, pipe: AAAImagePipeline, prompt):
        pipe.load_models_to_device(self.onload_model_names)
        text = pipe.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
        output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
        prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
        return {"prompt_embeds": prompt_embeds}


class AAAUnit_NoiseInitializer(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("height", "width", "seed", "rand_device"),
            output_params=("noise",),
        )

    def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
        noise = pipe.generate_noise((1, 128, height // 16, width // 16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
        return {"noise": noise}


class AAAUnit_InputImageEmbedder(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "noise"),
            output_params=("latents", "input_latents"),
            onload_model_names=("vae",),
        )

    def process(self, pipe: AAAImagePipeline, input_image, noise):
        if input_image is None:
            return {"latents": noise, "input_latents": None}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image)
        input_latents = pipe.vae.encode(image)
        if pipe.scheduler.training:
            return {"latents": noise, "input_latents": input_latents}
        else:
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
            return {"latents": latents, "input_latents": input_latents}


def model_fn_aaa(
    dit: AAADiT,
    latents=None,
    prompt_embeds=None,
    timestep=None,
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    model_output = dit(
        latents, prompt_embeds, timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    return model_output


class AAATrainingModule(DiffusionTrainingModule):
    def __init__(self, device):
        super().__init__()
        self.pipe = AAAImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=device,
            model_configs=[
                ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
                ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
            ],
            tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern=""),
        )
        # The DiT is trained from scratch; everything else stays frozen.
        self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
        self.pipe.freeze_except(["dit"])
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def forward(self, data):
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        inputs_shared = {
            "input_image": data["image"],
            # PIL images are sized (width, height).
            "height": data["image"].size[1],
            "width": data["image"].size[0],
            "cfg_scale": 1,
            "use_gradient_checkpointing": False,
            "use_gradient_checkpointing_offload": False,
        }
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
        return loss
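# Training entry point. UnifiedDataset is assumed to read one sample per
# metadata row; AAATrainingModule.forward expects each sample to provide an
# "image" (a PIL image after default_image_operator) and a "prompt" string,
# so the metadata CSV should contain matching columns. Check the
# DiffSynth-Studio documentation for the exact schema.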
if __name__ == "__main__":
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
    dataset = UnifiedDataset(
        base_path="data/images",
        metadata_path="data/metadata_merged.csv",
        max_data_items=10000000,
        data_file_keys=("image",),
        main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256),
    )
    model = AAATrainingModule(device=accelerator.device)
    model_logger = ModelLogger(
        "models/AAA/v1",
        remove_prefix_in_ckpt="pipe.dit.",
    )
    launch_training_task(
        accelerator, dataset, model, model_logger,
        learning_rate=2e-4,
        num_workers=4,
        save_steps=50000,
        num_epochs=999999,
    )
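# Typical launch command (assuming this file is saved as train_aaa.py):
#   accelerate launch train_aaa.py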