From 7af51b5e108f35a68ec6528d5b0649d4bffd00f4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 23 Dec 2025 17:47:42 +0800 Subject: [PATCH] zimagei2l --- diffsynth/configs/model_configs.py | 20 +++ diffsynth/models/z_image_image2lora.py | 112 +++++++++++++++ diffsynth/pipelines/z_image.py | 138 +++++++++++++++++++ prepare.py | 21 +++ run.sh | 14 ++ test.py | 58 ++++++++ train.py | 181 +++++++++++++++++++++++++ 7 files changed, 544 insertions(+) create mode 100644 diffsynth/models/z_image_image2lora.py create mode 100644 prepare.py create mode 100644 run.sh create mode 100644 test.py create mode 100644 train.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index dca078a..563015f 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -513,6 +513,26 @@ z_image_series = [ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers", "extra_kwargs": {"use_conv_attention": False}, }, + { + # Example: ??? + "model_hash": "4f04fa4db33673882c675f426bf42602", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + }, + { + # Example: ??? + "model_hash": "9510cb8cd1dd34ee0e4f111c24905510", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 128}, + }, + { + # Example: ??? + "model_hash": "cd7427f65cd4cc8092c00c373e2e0a23", + "model_name": "z_image_image2lora_style", + "model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel", + "extra_kwargs": {"compress_dim": 256}, + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series diff --git a/diffsynth/models/z_image_image2lora.py b/diffsynth/models/z_image_image2lora.py new file mode 100644 index 0000000..70d591b --- /dev/null +++ b/diffsynth/models/z_image_image2lora.py @@ -0,0 +1,112 @@ +import torch +from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP + + +class LoRATrainerBlock(torch.nn.Module): + def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"): + super().__init__() + self.prefix = prefix + self.lora_patterns = lora_patterns + self.block_id = block_id + self.layers = [] + for name, lora_a_dim, lora_b_dim in self.lora_patterns: + self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank)) + self.layers = torch.nn.ModuleList(self.layers) + if use_residual: + self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim) + else: + self.proj_residual = None + + def forward(self, x, residual=None): + lora = {} + if self.proj_residual is not None: residual = self.proj_residual(residual) + for lora_pattern, layer in zip(self.lora_patterns, self.layers): + name = lora_pattern[0] + lora_a, lora_b = layer(x, residual=residual) + lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a + lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b + return lora + + +class ZImageImage2LoRAComponent(torch.nn.Module): + def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + self.lora_patterns = lora_patterns + 
self.num_blocks = num_blocks + self.blocks = [] + for lora_patterns in self.lora_patterns: + for block_id in range(self.num_blocks): + self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix)) + self.blocks = torch.nn.ModuleList(self.blocks) + self.residual_scale = 0.05 + self.use_residual = use_residual + + def forward(self, x, residual=None): + if residual is not None: + if self.use_residual: + residual = residual * self.residual_scale + else: + residual = None + lora = {} + for block in self.blocks: + lora.update(block(x, residual)) + return lora + + +class ZImageImage2LoRAModel(torch.nn.Module): + def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024): + super().__init__() + lora_patterns = [ + [ + ("attention.to_q", 3840, 3840), + ("attention.to_k", 3840, 3840), + ("attention.to_v", 3840, 3840), + ("attention.to_out.0", 3840, 3840), + ], + [ + ("feed_forward.w1", 3840, 10240), + ("feed_forward.w2", 10240, 3840), + ("feed_forward.w3", 3840, 10240), + ], + ] + config = { + "lora_patterns": lora_patterns, + "use_residual": use_residual, + "compress_dim": compress_dim, + "rank": rank, + "residual_length": residual_length, + "residual_mid_dim": residual_mid_dim, + } + self.layers_lora = ZImageImage2LoRAComponent( + prefix="layers", + num_blocks=30, + **config, + ) + self.context_refiner_lora = ZImageImage2LoRAComponent( + prefix="context_refiner", + num_blocks=2, + **config, + ) + self.noise_refiner_lora = ZImageImage2LoRAComponent( + prefix="noise_refiner", + num_blocks=2, + **config, + ) + + def forward(self, x, residual=None): + lora = {} + lora.update(self.layers_lora(x, residual=residual)) + lora.update(self.context_refiner_lora(x, residual=residual)) + lora.update(self.noise_refiner_lora(x, residual=residual)) + return lora + + def initialize_weights(self): + state_dict = self.state_dict() + for name in state_dict: + if ".proj_a." in name: + state_dict[name] = state_dict[name] * 0.3 + elif ".proj_b.proj_out." in name: + state_dict[name] = state_dict[name] * 0 + elif ".proj_residual.proj_out." 
in name: + state_dict[name] = state_dict[name] * 0.3 + self.load_state_dict(state_dict) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index f87254f..e66c5ed 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -9,11 +9,15 @@ from typing import Union, List, Optional, Tuple from ..diffusion import FlowMatchScheduler from ..core import ModelConfig, gradient_checkpoint_forward from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput +from ..utils.lora import merge_lora from transformers import AutoTokenizer from ..models.z_image_text_encoder import ZImageTextEncoder from ..models.z_image_dit import ZImageDiT from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.siglip2_image_encoder import Siglip2ImageEncoder +from ..models.dinov3_image_encoder import DINOv3ImageEncoder +from ..models.z_image_image2lora import ZImageImage2LoRAModel class ZImagePipeline(BasePipeline): @@ -28,6 +32,9 @@ class ZImagePipeline(BasePipeline): self.dit: ZImageDiT = None self.vae_encoder: FluxVAEEncoder = None self.vae_decoder: FluxVAEDecoder = None + self.siglip2_image_encoder: Siglip2ImageEncoder = None + self.dinov3_image_encoder: DINOv3ImageEncoder = None + self.image2lora_style: ZImageImage2LoRAModel = None self.tokenizer: AutoTokenizer = None self.in_iteration_models = ("dit",) self.units = [ @@ -56,6 +63,9 @@ class ZImagePipeline(BasePipeline): pipe.dit = model_pool.fetch_model("z_image_dit") pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder") pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder") + pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder") + pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder") + pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style") if tokenizer_config is not None: tokenizer_config.download_if_necessary() pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) @@ -83,6 +93,8 @@ class ZImagePipeline(BasePipeline): rand_device: str = "cpu", # Steps num_inference_steps: int = 8, + # Image to LoRA + image2lora_images: List[Image.Image] = None, # Progress bar progress_bar_cmd = tqdm, ): @@ -102,6 +114,7 @@ class ZImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "num_inference_steps": num_inference_steps, + "image2lora_images": image2lora_images, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -234,6 +247,131 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit): return {"latents": latents, "input_latents": input_latents} +class ZImageUnit_Image2LoRAEncode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_images",), + output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",), + ) + from ..core.data.operators import ImageCropAndResize + self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8) + self.processor_highres = ImageCropAndResize(height=1024, width=1024) + + def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + return split_result + + def encode_prompt_edit(self, pipe: ZImagePipeline, prompt, 
edit_image): + prompt = [prompt] + template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + drop_idx = 64 + txt = [template.format(e) for e in prompt] + model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device) + hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1] + split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device) + return prompt_embeds.view(1, -1) + + def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["siglip2_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]): + pipe.load_models_to_device(["dinov3_image_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) + embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype)) + embs = torch.stack(embs) + return embs + + def encode_images_using_qwenvl(self, pipe: ZImagePipeline, images: list[Image.Image], highres=False): + pipe.load_models_to_device(["text_encoder"]) + embs = [] + for image in images: + image = self.processor_highres(image) if highres else self.processor_lowres(image) + embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image)) + embs = torch.stack(embs) + return embs + + def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]): + if images is None: + return {} + if not isinstance(images, list): + images = [images] + embs_siglip2 = self.encode_images_using_siglip2(pipe, images) + embs_dinov3 = self.encode_images_using_dinov3(pipe, images) + x = torch.concat([embs_siglip2, embs_dinov3], dim=-1) + residual = None + residual_highres = None + return x, residual, residual_highres + + def process(self, pipe: ZImagePipeline, image2lora_images): + if image2lora_images is None: + return {} + x, residual, residual_highres = self.encode_images(pipe, image2lora_images) + return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres} + + +class ZImageUnit_Image2LoRADecode(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"), + output_params=("lora",), + onload_model_names=("image2lora_style",), + ) + + def process(self, pipe: ZImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres): + if image2lora_x is None: + return {} + loras = [] + 
if pipe.image2lora_style is not None: + pipe.load_models_to_device(["image2lora_style"]) + for x in image2lora_x: + loras.append(pipe.image2lora_style(x=x, residual=None)) + lora = merge_lora(loras, alpha=1 / len(image2lora_x)) + return {"lora": lora} + + +class ZImageUnit_Image2LoRATraining(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("lora",), + ) + + def process(self, pipe: ZImagePipeline, lora): + if lora is None: + return {} + pipe.clear_lora() + pipe.load_lora(pipe.dit, state_dict=lora) + return {} + + +class ZImageUnit_DelUnusedParams(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + + def process(self, pipe: ZImagePipeline, inputs_shared, inputs_posi, inputs_nega): + if not pipe.scheduler.training: + return inputs_shared, inputs_posi, inputs_nega + if "input_image" in inputs_shared: inputs_shared.pop("input_image") + if "image2lora_images" in inputs_shared: inputs_shared.pop("image2lora_images") + if "noise" in inputs_shared: inputs_shared.pop("noise") + if "latents" in inputs_shared: inputs_shared.pop("latents") + return inputs_shared, inputs_posi, inputs_nega + def model_fn_z_image( dit: ZImageDiT, latents=None, diff --git a/prepare.py b/prepare.py new file mode 100644 index 0000000..725160e --- /dev/null +++ b/prepare.py @@ -0,0 +1,21 @@ +from diffsynth import load_state_dict, skip_model_initialization +from diffsynth.models.z_image_image2lora import ZImageImage2LoRAModel +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +import torch, os +from PIL import Image +from safetensors.torch import save_file + + +model = ZImageImage2LoRAModel(compress_dim=256).to("cuda").to(torch.bfloat16) +model.initialize_weights() +os.makedirs("models/train/Z-Image-i2L_v12", exist_ok=True) +save_file(model.state_dict(), "models/train/Z-Image-i2L_v12/model.safetensors") + +# check loading +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig("models/train/Z-Image-i2L_v12/model.safetensors"), + ], +) \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..37098cc --- /dev/null +++ b/run.sh @@ -0,0 +1,14 @@ +accelerate launch train.py \ + --dataset_base_path "" \ + --dataset_metadata_path data/metadata_sampled_110w.csv \ + --model_paths "models/train/Z-Image-i2L_v12/model.safetensors" \ + --max_pixels 1048576 \ + --dataset_repeat 100 \ + --learning_rate 1e-5 \ + --num_epochs 10000 \ + --remove_prefix_in_ckpt "pipe.image2lora_style." 
\ + --output_path "./models/train/Z-Image-i2L_v13" \ + --trainable_models "image2lora_style" \ + --dataset_num_workers 2 \ + --use_gradient_checkpointing \ + --save_steps 1000 diff --git a/test.py b/test.py new file mode 100644 index 0000000..119aa94 --- /dev/null +++ b/test.py @@ -0,0 +1,58 @@ +from diffsynth.pipelines.z_image import ( + ZImagePipeline, ModelConfig, + ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode +) +from modelscope import snapshot_download +from safetensors.torch import save_file +import torch +from PIL import Image + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cuda", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} + +# Load models +pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config), + ModelConfig("models/train/Z-Image-i2L_v13/step-58000.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + vram_limit=80, +) + +# Load images +snapshot_download( + model_id="DiffSynth-Studio/Qwen-Image-i2L", + allow_file_pattern="assets/style/*", + local_dir="data/examples" +) +for style_id in range(1, 5): + images = [Image.open(f"data/examples/assets/style/{style_id}/{i}.jpg") for i in range(4)] + + with torch.no_grad(): + embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images) + lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"] + + prompt = "a cat" + pipe.clear_lora() + pipe.load_lora(pipe.dit, state_dict=lora, alpha=1) + image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50) + image.save(f"image_lora_{style_id}.jpg") + +pipe.clear_lora() +image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50) +image.save("image_base.jpg") diff --git a/train.py b/train.py new file mode 100644 index 0000000..610a250 --- /dev/null +++ b/train.py @@ -0,0 +1,181 @@ +import torch, os, argparse, accelerate, copy +from diffsynth.core import UnifiedDataset +from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig +from diffsynth.pipelines.z_image import ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode, ZImageUnit_Image2LoRATraining +from diffsynth.diffusion import * +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class ZImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + preset_lora_path=None, preset_lora_model=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + fp8_models=None, + offload_models=None, 
+ device="cpu", + task="sft", + ): + super().__init__() + # Load models + vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": device, + "onload_dtype": torch.bfloat16, + "onload_device": device, + "preparing_dtype": torch.bfloat16, + "preparing_device": device, + "computation_dtype": torch.bfloat16, + "computation_device": device, + } + self.pipe = ZImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device=device, + model_configs=[ + ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"), + ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"), + ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"), + ModelConfig(model_paths), + ], + tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), + ) + self.pipe.vram_management_enabled = False + self.pipe.units = self.pipe.units + [ + ZImageUnit_Image2LoRAEncode(), + ZImageUnit_Image2LoRADecode(), + ZImageUnit_Image2LoRATraining(), + ] + self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, + preset_lora_path, preset_lora_model, + task=task, + ) + + # Other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.fp8_models = fp8_models + self.task = task + self.task_to_loss = { + "sft:data_process": lambda pipe, *args: args, + "direct_distill:data_process": lambda pipe, *args: args, + "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), + } + if task == "trajectory_imitation": + # This is an experimental feature. + # We may remove it in the future. + self.loss_fn = TrajectoryImitationLoss() + self.task_to_loss["trajectory_imitation"] = self.loss_fn + self.pipe_teacher = copy.deepcopy(self.pipe) + self.pipe_teacher.requires_grad_(False) + + def get_pipeline_inputs(self, data): + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + "image2lora_images": data["image"], + # Please do not modify the following parameters + # unless you clearly know what this will cause. 
+ "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + } + if self.task == "trajectory_imitation": + inputs_shared["cfg_scale"] = 2 + inputs_shared["teacher"] = self.pipe_teacher + inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) + return inputs_shared, inputs_posi, inputs_nega + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.get_pipeline_inputs(data) + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + for unit in self.pipe.units: + inputs = self.pipe.unit_runner(unit, self.pipe, *inputs) + loss = self.task_to_loss[self.task](self.pipe, *inputs) + return loss + + +def z_image_parser(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser = add_general_config(parser) + parser = add_image_size_config(parser) + parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") + return parser + + +if __name__ == "__main__": + parser = z_image_parser() + args = parser.parse_args() + accelerator = accelerate.Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], + ) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = ZImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + preset_lora_path=args.preset_lora_path, + preset_lora_model=args.preset_lora_model, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + fp8_models=args.fp8_models, + offload_models=args.offload_models, + task=args.task, + device=accelerator.device, + ) + model_logger = ModelLogger( + args.output_path, + remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + ) + launcher_map = { + "sft:data_process": launch_data_process_task, + "direct_distill:data_process": launch_data_process_task, + "sft": launch_training_task, + "sft:train": launch_training_task, + "direct_distill": launch_training_task, + "direct_distill:train": launch_training_task, + "trajectory_imitation": launch_training_task, + } + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
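
Usage note (not part of the diff above): the new files fit together as a workflow — prepare.py initializes and saves a fresh ZImageImage2LoRAModel checkpoint, run.sh/train.py train it with "image2lora_style" as the only trainable model, and test.py turns a handful of reference images into a LoRA at inference time and applies it to the DiT. Below is a minimal sketch of that inference flow, condensed from test.py in this diff; the model IDs, file patterns, and the step-58000 checkpoint path are the ones used there (substitute your own trained weights), and the VRAM-management options from test.py are omitted for brevity.

import torch
from PIL import Image
from diffsynth.pipelines.z_image import (
    ZImagePipeline, ModelConfig,
    ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode,
)

# Load the DiT, text encoder, VAE, the two image encoders, and the trained image2lora head.
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
        ModelConfig("models/train/Z-Image-i2L_v13/step-58000.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)

# Encode a few reference images into image embeddings, decode them into a LoRA
# state dict, and load that LoRA onto the DiT before sampling.
style_images = [Image.open(f"data/examples/assets/style/1/{i}.jpg") for i in range(4)]
with torch.no_grad():
    embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=style_images)
    lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
pipe.clear_lora()
pipe.load_lora(pipe.dit, state_dict=lora, alpha=1)

image = pipe(prompt="a cat", seed=123, cfg_scale=4, num_inference_steps=50)
image.save("image_lora.jpg")

During training, the same encode/decode steps run as pipeline units (appended to pipe.units in train.py), and ZImageUnit_Image2LoRATraining loads the predicted LoRA onto the DiT before the flow-matching loss is computed, so gradients flow back into the image2lora_style module only.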