mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
zimagei2l
This commit is contained in:
@@ -513,6 +513,26 @@ z_image_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||
"extra_kwargs": {"use_conv_attention": False},
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "4f04fa4db33673882c675f426bf42602",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 128},
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "cd7427f65cd4cc8092c00c373e2e0a23",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 256},
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
112
diffsynth/models/z_image_image2lora.py
Normal file
112
diffsynth/models/z_image_image2lora.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
||||
|
||||
|
||||
class LoRATrainerBlock(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
self.lora_patterns = lora_patterns
|
||||
self.block_id = block_id
|
||||
self.layers = []
|
||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
||||
self.layers = torch.nn.ModuleList(self.layers)
|
||||
if use_residual:
|
||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
||||
else:
|
||||
self.proj_residual = None
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
||||
name = lora_pattern[0]
|
||||
lora_a, lora_b = layer(x, residual=residual)
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAComponent(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
self.lora_patterns = lora_patterns
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks = []
|
||||
for lora_patterns in self.lora_patterns:
|
||||
for block_id in range(self.num_blocks):
|
||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
||||
self.residual_scale = 0.05
|
||||
self.use_residual = use_residual
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
if residual is not None:
|
||||
if self.use_residual:
|
||||
residual = residual * self.residual_scale
|
||||
else:
|
||||
residual = None
|
||||
lora = {}
|
||||
for block in self.blocks:
|
||||
lora.update(block(x, residual))
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAModel(torch.nn.Module):
|
||||
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
lora_patterns = [
|
||||
[
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
],
|
||||
[
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
],
|
||||
]
|
||||
config = {
|
||||
"lora_patterns": lora_patterns,
|
||||
"use_residual": use_residual,
|
||||
"compress_dim": compress_dim,
|
||||
"rank": rank,
|
||||
"residual_length": residual_length,
|
||||
"residual_mid_dim": residual_mid_dim,
|
||||
}
|
||||
self.layers_lora = ZImageImage2LoRAComponent(
|
||||
prefix="layers",
|
||||
num_blocks=30,
|
||||
**config,
|
||||
)
|
||||
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="context_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="noise_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
lora.update(self.layers_lora(x, residual=residual))
|
||||
lora.update(self.context_refiner_lora(x, residual=residual))
|
||||
lora.update(self.noise_refiner_lora(x, residual=residual))
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if ".proj_a." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
elif ".proj_b.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif ".proj_residual.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
self.load_state_dict(state_dict)
|
||||
@@ -9,11 +9,15 @@ from typing import Union, List, Optional, Tuple
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
from ..utils.lora import merge_lora
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
||||
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
@@ -28,6 +32,9 @@ class ZImagePipeline(BasePipeline):
|
||||
self.dit: ZImageDiT = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
||||
self.image2lora_style: ZImageImage2LoRAModel = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
@@ -56,6 +63,9 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
||||
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -83,6 +93,8 @@ class ZImagePipeline(BasePipeline):
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
# Image to LoRA
|
||||
image2lora_images: List[Image.Image] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -102,6 +114,7 @@ class ZImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"image2lora_images": image2lora_images,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -234,6 +247,131 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_images",),
|
||||
output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
||||
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
|
||||
)
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)
|
||||
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
||||
|
||||
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||
bool_mask = mask.bool()
|
||||
valid_lengths = bool_mask.sum(dim=1)
|
||||
selected = hidden_states[bool_mask]
|
||||
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||
return split_result
|
||||
|
||||
def encode_prompt_edit(self, pipe: ZImagePipeline, prompt, edit_image):
|
||||
prompt = [prompt]
|
||||
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
drop_idx = 64
|
||||
txt = [template.format(e) for e in prompt]
|
||||
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return prompt_embeds.view(1, -1)
|
||||
|
||||
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["siglip2_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["dinov3_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images_using_qwenvl(self, pipe: ZImagePipeline, images: list[Image.Image], highres=False):
|
||||
pipe.load_models_to_device(["text_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image) if highres else self.processor_lowres(image)
|
||||
embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
if images is None:
|
||||
return {}
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
||||
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
||||
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
||||
residual = None
|
||||
residual_highres = None
|
||||
return x, residual, residual_highres
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_images):
|
||||
if image2lora_images is None:
|
||||
return {}
|
||||
x, residual, residual_highres = self.encode_images(pipe, image2lora_images)
|
||||
return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRADecode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
||||
output_params=("lora",),
|
||||
onload_model_names=("image2lora_style",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres):
|
||||
if image2lora_x is None:
|
||||
return {}
|
||||
loras = []
|
||||
if pipe.image2lora_style is not None:
|
||||
pipe.load_models_to_device(["image2lora_style"])
|
||||
for x in image2lora_x:
|
||||
loras.append(pipe.image2lora_style(x=x, residual=None))
|
||||
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
||||
return {"lora": lora}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRATraining(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("lora",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, lora):
|
||||
if lora is None:
|
||||
return {}
|
||||
pipe.clear_lora()
|
||||
pipe.load_lora(pipe.dit, state_dict=lora)
|
||||
return {}
|
||||
|
||||
|
||||
class ZImageUnit_DelUnusedParams(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(take_over=True)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if not pipe.scheduler.training:
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
if "input_image" in inputs_shared: inputs_shared.pop("input_image")
|
||||
if "image2lora_images" in inputs_shared: inputs_shared.pop("image2lora_images")
|
||||
if "noise" in inputs_shared: inputs_shared.pop("noise")
|
||||
if "latents" in inputs_shared: inputs_shared.pop("latents")
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
latents=None,
|
||||
|
||||
21
prepare.py
Normal file
21
prepare.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from diffsynth import load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.z_image_image2lora import ZImageImage2LoRAModel
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||
import torch, os
|
||||
from PIL import Image
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
model = ZImageImage2LoRAModel(compress_dim=256).to("cuda").to(torch.bfloat16)
|
||||
model.initialize_weights()
|
||||
os.makedirs("models/train/Z-Image-i2L_v12", exist_ok=True)
|
||||
save_file(model.state_dict(), "models/train/Z-Image-i2L_v12/model.safetensors")
|
||||
|
||||
# check loading
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig("models/train/Z-Image-i2L_v12/model.safetensors"),
|
||||
],
|
||||
)
|
||||
14
run.sh
Normal file
14
run.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
accelerate launch train.py \
|
||||
--dataset_base_path "" \
|
||||
--dataset_metadata_path data/metadata_sampled_110w.csv \
|
||||
--model_paths "models/train/Z-Image-i2L_v12/model.safetensors" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 100 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 10000 \
|
||||
--remove_prefix_in_ckpt "pipe.image2lora_style." \
|
||||
--output_path "./models/train/Z-Image-i2L_v13" \
|
||||
--trainable_models "image2lora_style" \
|
||||
--dataset_num_workers 2 \
|
||||
--use_gradient_checkpointing \
|
||||
--save_steps 1000
|
||||
58
test.py
Normal file
58
test.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from diffsynth.pipelines.z_image import (
|
||||
ZImagePipeline, ModelConfig,
|
||||
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||
)
|
||||
from modelscope import snapshot_download
|
||||
from safetensors.torch import save_file
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cuda",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
# Load models
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config),
|
||||
ModelConfig("models/train/Z-Image-i2L_v13/step-58000.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=80,
|
||||
)
|
||||
|
||||
# Load images
|
||||
snapshot_download(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
||||
allow_file_pattern="assets/style/*",
|
||||
local_dir="data/examples"
|
||||
)
|
||||
for style_id in range(1, 5):
|
||||
images = [Image.open(f"data/examples/assets/style/{style_id}/{i}.jpg") for i in range(4)]
|
||||
|
||||
with torch.no_grad():
|
||||
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
||||
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
||||
|
||||
prompt = "a cat"
|
||||
pipe.clear_lora()
|
||||
pipe.load_lora(pipe.dit, state_dict=lora, alpha=1)
|
||||
image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50)
|
||||
image.save(f"image_lora_{style_id}.jpg")
|
||||
|
||||
pipe.clear_lora()
|
||||
image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50)
|
||||
image.save("image_base.jpg")
|
||||
181
train.py
Normal file
181
train.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch, os, argparse, accelerate, copy
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.z_image import ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode, ZImageUnit_Image2LoRATraining
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": device,
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": device,
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": device,
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": device,
|
||||
}
|
||||
self.pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
||||
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
||||
ModelConfig(model_paths),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
self.pipe.vram_management_enabled = False
|
||||
self.pipe.units = self.pipe.units + [
|
||||
ZImageUnit_Image2LoRAEncode(),
|
||||
ZImageUnit_Image2LoRADecode(),
|
||||
ZImageUnit_Image2LoRATraining(),
|
||||
]
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# Training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||
preset_lora_path, preset_lora_model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
# Other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.fp8_models = fp8_models
|
||||
self.task = task
|
||||
self.task_to_loss = {
|
||||
"sft:data_process": lambda pipe, *args: args,
|
||||
"direct_distill:data_process": lambda pipe, *args: args,
|
||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
if task == "trajectory_imitation":
|
||||
# This is an experimental feature.
|
||||
# We may remove it in the future.
|
||||
self.loss_fn = TrajectoryImitationLoss()
|
||||
self.task_to_loss["trajectory_imitation"] = self.loss_fn
|
||||
self.pipe_teacher = copy.deepcopy(self.pipe)
|
||||
self.pipe_teacher.requires_grad_(False)
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"image2lora_images": data["image"],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
if self.task == "trajectory_imitation":
|
||||
inputs_shared["cfg_scale"] = 2
|
||||
inputs_shared["teacher"] = self.pipe_teacher
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
for unit in self.pipe.units:
|
||||
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||
return loss
|
||||
|
||||
|
||||
def z_image_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = z_image_parser()
|
||||
args = parser.parse_args()
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = ZImageTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
preset_lora_model=args.preset_lora_model,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device=accelerator.device,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
)
|
||||
launcher_map = {
|
||||
"sft:data_process": launch_data_process_task,
|
||||
"direct_distill:data_process": launch_data_process_task,
|
||||
"sft": launch_training_task,
|
||||
"sft:train": launch_training_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"direct_distill:train": launch_training_task,
|
||||
"trajectory_imitation": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
Reference in New Issue
Block a user