mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7af51b5e10 |
@@ -513,6 +513,26 @@ z_image_series = [
|
|||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
||||||
"extra_kwargs": {"use_conv_attention": False},
|
"extra_kwargs": {"use_conv_attention": False},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
# Example: ???
|
||||||
|
"model_hash": "4f04fa4db33673882c675f426bf42602",
|
||||||
|
"model_name": "z_image_image2lora_style",
|
||||||
|
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ???
|
||||||
|
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||||
|
"model_name": "z_image_image2lora_style",
|
||||||
|
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||||
|
"extra_kwargs": {"compress_dim": 128},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ???
|
||||||
|
"model_hash": "cd7427f65cd4cc8092c00c373e2e0a23",
|
||||||
|
"model_name": "z_image_image2lora_style",
|
||||||
|
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||||
|
"extra_kwargs": {"compress_dim": 256},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||||
|
|||||||
112
diffsynth/models/z_image_image2lora.py
Normal file
112
diffsynth/models/z_image_image2lora.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
||||||
|
|
||||||
|
|
||||||
|
class LoRATrainerBlock(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
||||||
|
super().__init__()
|
||||||
|
self.prefix = prefix
|
||||||
|
self.lora_patterns = lora_patterns
|
||||||
|
self.block_id = block_id
|
||||||
|
self.layers = []
|
||||||
|
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
||||||
|
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
||||||
|
self.layers = torch.nn.ModuleList(self.layers)
|
||||||
|
if use_residual:
|
||||||
|
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
||||||
|
else:
|
||||||
|
self.proj_residual = None
|
||||||
|
|
||||||
|
def forward(self, x, residual=None):
|
||||||
|
lora = {}
|
||||||
|
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
||||||
|
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
||||||
|
name = lora_pattern[0]
|
||||||
|
lora_a, lora_b = layer(x, residual=residual)
|
||||||
|
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
||||||
|
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
||||||
|
return lora
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageImage2LoRAComponent(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.lora_patterns = lora_patterns
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.blocks = []
|
||||||
|
for lora_patterns in self.lora_patterns:
|
||||||
|
for block_id in range(self.num_blocks):
|
||||||
|
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
||||||
|
self.blocks = torch.nn.ModuleList(self.blocks)
|
||||||
|
self.residual_scale = 0.05
|
||||||
|
self.use_residual = use_residual
|
||||||
|
|
||||||
|
def forward(self, x, residual=None):
|
||||||
|
if residual is not None:
|
||||||
|
if self.use_residual:
|
||||||
|
residual = residual * self.residual_scale
|
||||||
|
else:
|
||||||
|
residual = None
|
||||||
|
lora = {}
|
||||||
|
for block in self.blocks:
|
||||||
|
lora.update(block(x, residual))
|
||||||
|
return lora
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageImage2LoRAModel(torch.nn.Module):
|
||||||
|
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
lora_patterns = [
|
||||||
|
[
|
||||||
|
("attention.to_q", 3840, 3840),
|
||||||
|
("attention.to_k", 3840, 3840),
|
||||||
|
("attention.to_v", 3840, 3840),
|
||||||
|
("attention.to_out.0", 3840, 3840),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
("feed_forward.w1", 3840, 10240),
|
||||||
|
("feed_forward.w2", 10240, 3840),
|
||||||
|
("feed_forward.w3", 3840, 10240),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
config = {
|
||||||
|
"lora_patterns": lora_patterns,
|
||||||
|
"use_residual": use_residual,
|
||||||
|
"compress_dim": compress_dim,
|
||||||
|
"rank": rank,
|
||||||
|
"residual_length": residual_length,
|
||||||
|
"residual_mid_dim": residual_mid_dim,
|
||||||
|
}
|
||||||
|
self.layers_lora = ZImageImage2LoRAComponent(
|
||||||
|
prefix="layers",
|
||||||
|
num_blocks=30,
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
||||||
|
prefix="context_refiner",
|
||||||
|
num_blocks=2,
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
||||||
|
prefix="noise_refiner",
|
||||||
|
num_blocks=2,
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, residual=None):
|
||||||
|
lora = {}
|
||||||
|
lora.update(self.layers_lora(x, residual=residual))
|
||||||
|
lora.update(self.context_refiner_lora(x, residual=residual))
|
||||||
|
lora.update(self.noise_refiner_lora(x, residual=residual))
|
||||||
|
return lora
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
for name in state_dict:
|
||||||
|
if ".proj_a." in name:
|
||||||
|
state_dict[name] = state_dict[name] * 0.3
|
||||||
|
elif ".proj_b.proj_out." in name:
|
||||||
|
state_dict[name] = state_dict[name] * 0
|
||||||
|
elif ".proj_residual.proj_out." in name:
|
||||||
|
state_dict[name] = state_dict[name] * 0.3
|
||||||
|
self.load_state_dict(state_dict)
|
||||||
@@ -9,11 +9,15 @@ from typing import Union, List, Optional, Tuple
|
|||||||
from ..diffusion import FlowMatchScheduler
|
from ..diffusion import FlowMatchScheduler
|
||||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||||
|
from ..utils.lora import merge_lora
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||||
from ..models.z_image_dit import ZImageDiT
|
from ..models.z_image_dit import ZImageDiT
|
||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
|
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
||||||
|
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
||||||
|
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
||||||
|
|
||||||
|
|
||||||
class ZImagePipeline(BasePipeline):
|
class ZImagePipeline(BasePipeline):
|
||||||
@@ -28,6 +32,9 @@ class ZImagePipeline(BasePipeline):
|
|||||||
self.dit: ZImageDiT = None
|
self.dit: ZImageDiT = None
|
||||||
self.vae_encoder: FluxVAEEncoder = None
|
self.vae_encoder: FluxVAEEncoder = None
|
||||||
self.vae_decoder: FluxVAEDecoder = None
|
self.vae_decoder: FluxVAEDecoder = None
|
||||||
|
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
||||||
|
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
||||||
|
self.image2lora_style: ZImageImage2LoRAModel = None
|
||||||
self.tokenizer: AutoTokenizer = None
|
self.tokenizer: AutoTokenizer = None
|
||||||
self.in_iteration_models = ("dit",)
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
@@ -56,6 +63,9 @@ class ZImagePipeline(BasePipeline):
|
|||||||
pipe.dit = model_pool.fetch_model("z_image_dit")
|
pipe.dit = model_pool.fetch_model("z_image_dit")
|
||||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||||
|
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
||||||
|
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
||||||
|
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
||||||
if tokenizer_config is not None:
|
if tokenizer_config is not None:
|
||||||
tokenizer_config.download_if_necessary()
|
tokenizer_config.download_if_necessary()
|
||||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
@@ -83,6 +93,8 @@ class ZImagePipeline(BasePipeline):
|
|||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 8,
|
num_inference_steps: int = 8,
|
||||||
|
# Image to LoRA
|
||||||
|
image2lora_images: List[Image.Image] = None,
|
||||||
# Progress bar
|
# Progress bar
|
||||||
progress_bar_cmd = tqdm,
|
progress_bar_cmd = tqdm,
|
||||||
):
|
):
|
||||||
@@ -102,6 +114,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"num_inference_steps": num_inference_steps,
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"image2lora_images": image2lora_images,
|
||||||
}
|
}
|
||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
@@ -234,6 +247,131 @@ class ZImageUnit_InputImageEmbedder(PipelineUnit):
|
|||||||
return {"latents": latents, "input_latents": input_latents}
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("image2lora_images",),
|
||||||
|
output_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
||||||
|
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
|
||||||
|
)
|
||||||
|
from ..core.data.operators import ImageCropAndResize
|
||||||
|
self.processor_lowres = ImageCropAndResize(height=28*8, width=28*8)
|
||||||
|
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
||||||
|
|
||||||
|
def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
|
||||||
|
bool_mask = mask.bool()
|
||||||
|
valid_lengths = bool_mask.sum(dim=1)
|
||||||
|
selected = hidden_states[bool_mask]
|
||||||
|
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
|
||||||
|
return split_result
|
||||||
|
|
||||||
|
def encode_prompt_edit(self, pipe: ZImagePipeline, prompt, edit_image):
|
||||||
|
prompt = [prompt]
|
||||||
|
template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
drop_idx = 64
|
||||||
|
txt = [template.format(e) for e in prompt]
|
||||||
|
model_inputs = pipe.processor(text=txt, images=edit_image, padding=True, return_tensors="pt").to(pipe.device)
|
||||||
|
hidden_states = pipe.text_encoder(input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, pixel_values=model_inputs.pixel_values, image_grid_thw=model_inputs.image_grid_thw, output_hidden_states=True,)[-1]
|
||||||
|
split_hidden_states = self.extract_masked_hidden(hidden_states, model_inputs.attention_mask)
|
||||||
|
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||||
|
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||||
|
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||||
|
prompt_embeds = prompt_embeds.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
return prompt_embeds.view(1, -1)
|
||||||
|
|
||||||
|
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||||
|
pipe.load_models_to_device(["siglip2_image_encoder"])
|
||||||
|
embs = []
|
||||||
|
for image in images:
|
||||||
|
image = self.processor_highres(image)
|
||||||
|
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
||||||
|
embs = torch.stack(embs)
|
||||||
|
return embs
|
||||||
|
|
||||||
|
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||||
|
pipe.load_models_to_device(["dinov3_image_encoder"])
|
||||||
|
embs = []
|
||||||
|
for image in images:
|
||||||
|
image = self.processor_highres(image)
|
||||||
|
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
||||||
|
embs = torch.stack(embs)
|
||||||
|
return embs
|
||||||
|
|
||||||
|
def encode_images_using_qwenvl(self, pipe: ZImagePipeline, images: list[Image.Image], highres=False):
|
||||||
|
pipe.load_models_to_device(["text_encoder"])
|
||||||
|
embs = []
|
||||||
|
for image in images:
|
||||||
|
image = self.processor_highres(image) if highres else self.processor_lowres(image)
|
||||||
|
embs.append(self.encode_prompt_edit(pipe, prompt="", edit_image=image))
|
||||||
|
embs = torch.stack(embs)
|
||||||
|
return embs
|
||||||
|
|
||||||
|
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||||
|
if images is None:
|
||||||
|
return {}
|
||||||
|
if not isinstance(images, list):
|
||||||
|
images = [images]
|
||||||
|
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
||||||
|
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
||||||
|
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
||||||
|
residual = None
|
||||||
|
residual_highres = None
|
||||||
|
return x, residual, residual_highres
|
||||||
|
|
||||||
|
def process(self, pipe: ZImagePipeline, image2lora_images):
|
||||||
|
if image2lora_images is None:
|
||||||
|
return {}
|
||||||
|
x, residual, residual_highres = self.encode_images(pipe, image2lora_images)
|
||||||
|
return {"image2lora_x": x, "image2lora_residual": residual, "image2lora_residual_highres": residual_highres}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_Image2LoRADecode(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("image2lora_x", "image2lora_residual", "image2lora_residual_highres"),
|
||||||
|
output_params=("lora",),
|
||||||
|
onload_model_names=("image2lora_style",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: ZImagePipeline, image2lora_x, image2lora_residual, image2lora_residual_highres):
|
||||||
|
if image2lora_x is None:
|
||||||
|
return {}
|
||||||
|
loras = []
|
||||||
|
if pipe.image2lora_style is not None:
|
||||||
|
pipe.load_models_to_device(["image2lora_style"])
|
||||||
|
for x in image2lora_x:
|
||||||
|
loras.append(pipe.image2lora_style(x=x, residual=None))
|
||||||
|
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
||||||
|
return {"lora": lora}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_Image2LoRATraining(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("lora",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: ZImagePipeline, lora):
|
||||||
|
if lora is None:
|
||||||
|
return {}
|
||||||
|
pipe.clear_lora()
|
||||||
|
pipe.load_lora(pipe.dit, state_dict=lora)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageUnit_DelUnusedParams(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(take_over=True)
|
||||||
|
|
||||||
|
def process(self, pipe: ZImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
|
if not pipe.scheduler.training:
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
if "input_image" in inputs_shared: inputs_shared.pop("input_image")
|
||||||
|
if "image2lora_images" in inputs_shared: inputs_shared.pop("image2lora_images")
|
||||||
|
if "noise" in inputs_shared: inputs_shared.pop("noise")
|
||||||
|
if "latents" in inputs_shared: inputs_shared.pop("latents")
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
def model_fn_z_image(
|
def model_fn_z_image(
|
||||||
dit: ZImageDiT,
|
dit: ZImageDiT,
|
||||||
latents=None,
|
latents=None,
|
||||||
|
|||||||
21
prepare.py
Normal file
21
prepare.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from diffsynth import load_state_dict, skip_model_initialization
|
||||||
|
from diffsynth.models.z_image_image2lora import ZImageImage2LoRAModel
|
||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||||
|
import torch, os
|
||||||
|
from PIL import Image
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
|
||||||
|
model = ZImageImage2LoRAModel(compress_dim=256).to("cuda").to(torch.bfloat16)
|
||||||
|
model.initialize_weights()
|
||||||
|
os.makedirs("models/train/Z-Image-i2L_v12", exist_ok=True)
|
||||||
|
save_file(model.state_dict(), "models/train/Z-Image-i2L_v12/model.safetensors")
|
||||||
|
|
||||||
|
# check loading
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig("models/train/Z-Image-i2L_v12/model.safetensors"),
|
||||||
|
],
|
||||||
|
)
|
||||||
14
run.sh
Normal file
14
run.sh
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
accelerate launch train.py \
|
||||||
|
--dataset_base_path "" \
|
||||||
|
--dataset_metadata_path data/metadata_sampled_110w.csv \
|
||||||
|
--model_paths "models/train/Z-Image-i2L_v12/model.safetensors" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 10000 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.image2lora_style." \
|
||||||
|
--output_path "./models/train/Z-Image-i2L_v13" \
|
||||||
|
--trainable_models "image2lora_style" \
|
||||||
|
--dataset_num_workers 2 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--save_steps 1000
|
||||||
58
test.py
Normal file
58
test.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from diffsynth.pipelines.z_image import (
|
||||||
|
ZImagePipeline, ModelConfig,
|
||||||
|
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||||
|
)
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cuda",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config),
|
||||||
|
ModelConfig("models/train/Z-Image-i2L_v13/step-58000.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=80,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
snapshot_download(
|
||||||
|
model_id="DiffSynth-Studio/Qwen-Image-i2L",
|
||||||
|
allow_file_pattern="assets/style/*",
|
||||||
|
local_dir="data/examples"
|
||||||
|
)
|
||||||
|
for style_id in range(1, 5):
|
||||||
|
images = [Image.open(f"data/examples/assets/style/{style_id}/{i}.jpg") for i in range(4)]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
||||||
|
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
||||||
|
|
||||||
|
prompt = "a cat"
|
||||||
|
pipe.clear_lora()
|
||||||
|
pipe.load_lora(pipe.dit, state_dict=lora, alpha=1)
|
||||||
|
image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50)
|
||||||
|
image.save(f"image_lora_{style_id}.jpg")
|
||||||
|
|
||||||
|
pipe.clear_lora()
|
||||||
|
image = pipe(prompt=prompt, seed=123, cfg_scale=4, num_inference_steps=50)
|
||||||
|
image.save("image_base.jpg")
|
||||||
181
train.py
Normal file
181
train.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
import torch, os, argparse, accelerate, copy
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
from diffsynth.pipelines.z_image import ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode, ZImageUnit_Image2LoRATraining
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Load models
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": device,
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": device,
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": device,
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": device,
|
||||||
|
}
|
||||||
|
self.pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Base-1211_Temp", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
||||||
|
ModelConfig(model_paths),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
self.pipe.vram_management_enabled = False
|
||||||
|
self.pipe.units = self.pipe.units + [
|
||||||
|
ZImageUnit_Image2LoRAEncode(),
|
||||||
|
ZImageUnit_Image2LoRADecode(),
|
||||||
|
ZImageUnit_Image2LoRATraining(),
|
||||||
|
]
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"direct_distill:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
if task == "trajectory_imitation":
|
||||||
|
# This is an experimental feature.
|
||||||
|
# We may remove it in the future.
|
||||||
|
self.loss_fn = TrajectoryImitationLoss()
|
||||||
|
self.task_to_loss["trajectory_imitation"] = self.loss_fn
|
||||||
|
self.pipe_teacher = copy.deepcopy(self.pipe)
|
||||||
|
self.pipe_teacher.requires_grad_(False)
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
"image2lora_images": data["image"],
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
}
|
||||||
|
if self.task == "trajectory_imitation":
|
||||||
|
inputs_shared["cfg_scale"] = 2
|
||||||
|
inputs_shared["teacher"] = self.pipe_teacher
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def z_image_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_image_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = z_image_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
model = ZImageTrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device=accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
"trajectory_imitation": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
Reference in New Issue
Block a user