support z-image controlnet

Artiprocher
2026-01-07 15:56:53 +08:00
parent 32449a6aa0
commit bac39b1cd2
28 changed files with 868 additions and 11 deletions

View File

@@ -540,6 +540,12 @@ z_image_series = [
"model_name": "siglip_vision_model_428m", "model_name": "siglip_vision_model_428m",
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M", "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
}, },
{
# ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
"model_hash": "1677708d40029ab380a95f6c731a57d7",
"model_name": "z_image_controlnet",
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
}
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series

View File

@@ -195,4 +195,8 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
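
The map above is consumed by diffsynth's VRAM management, which presumably walks each registered model and swaps matching submodules for the listed wrapper classes so weights can stay offloaded and be cast/moved per call. A minimal sketch of that swap pattern, with a hypothetical ToyWrappedLinear standing in for the real AutoWrappedLinear (whose exact behavior is not shown in this diff):

from torch import nn

class ToyWrappedLinear(nn.Module):
    # Hypothetical stand-in for AutoWrappedLinear: keeps the wrapped weights
    # wherever they were offloaded and casts/moves them on each forward call.
    def __init__(self, inner: nn.Linear):
        super().__init__()
        self.inner = inner
    def forward(self, x):
        w = self.inner.weight.to(device=x.device, dtype=x.dtype)
        b = self.inner.bias
        if b is not None:
            b = b.to(device=x.device, dtype=x.dtype)
        return nn.functional.linear(x, w, b)

def swap_modules(model: nn.Module, module_map: dict):
    # Recursively replace every child whose exact class is a key in the map.
    for name, child in list(model.named_children()):
        if type(child) in module_map:
            setattr(model, name, module_map[type(child)](child))
        else:
            swap_modules(child, module_map)

# e.g. swap_modules(controlnet, {nn.Linear: ToyWrappedLinear})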

View File

@@ -0,0 +1,154 @@
from .z_image_dit import ZImageTransformerBlock
from ..core.gradient import gradient_checkpoint_forward
from torch.nn.utils.rnn import pad_sequence
import torch
from torch import nn
class ZImageControlTransformerBlock(ZImageTransformerBlock):
def __init__(
self,
layer_id: int = 1000,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
norm_eps: float = 1e-5,
qk_norm: bool = True,
modulation = True,
block_id = 0
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
self.after_proj = nn.Linear(self.dim, self.dim)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
class ZImageControlNet(torch.nn.Module):
def __init__(
self,
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
control_in_dim=33,
dim=3840,
n_refiner_layers=2,
):
super().__init__()
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
def forward_layers(
self,
x,
cap_feats,
control_context,
control_context_item_seqlens,
kwargs,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
bsz = len(control_context)
# unified
cap_item_seqlens = [len(_) for _ in cap_feats]
control_context_unified = []
for i in range(bsz):
control_context_len = control_context_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
for layer in self.control_layers:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
return hints
def forward_refiner(
self,
dit,
x,
cap_feats,
control_context,
kwargs,
t=None,
patch_size=2,
f_patch_size=1,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
):
# embeddings
bsz = len(control_context)
device = control_context[0].device
(
control_context,
control_context_size,
control_context_pos_ids,
control_context_inner_pad_mask,
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
# control_context embed & refine
control_context_item_seqlens = [len(_) for _ in control_context]
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
control_context_max_item_seqlen = max(control_context_item_seqlens)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
# Match t_embedder output dtype to control_context for layerwise casting compatibility
adaln_input = t.type_as(control_context)
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(control_context_item_seqlens):
control_context_attn_mask[i, :seq_len] = 1
c = control_context
# arguments
new_kwargs = dict(
x=x,
attn_mask=control_context_attn_mask,
freqs_cis=control_context_freqs_cis,
adaln_input=adaln_input,
)
new_kwargs.update(kwargs)
for layer in self.control_noise_refiner:
c = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
c=c, **new_kwargs
)
hints = torch.unbind(c)[:-1]
control_context = torch.unbind(c)[-1]
return hints, control_context, control_context_item_seqlens
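
A note on the file above: the control blocks thread their per-layer hint projections through a single stacked tensor whose last slice is the running hidden state, so each block pops the state, runs the usual transformer body, and pushes (hint, state) back. A toy replay of that bookkeeping, with a scalar op standing in for the transformer body and a lambda for after_proj:

import torch

def toy_block(c, x, block_id, after_proj):
    if block_id == 0:
        c = c + x                      # before_proj omitted in this toy
        all_c = []
    else:
        all_c = list(torch.unbind(c))  # earlier hints + running state
        c = all_c.pop(-1)
    c = c * 2.0                        # stand-in for the transformer body
    all_c += [after_proj(c), c]        # push (hint, running state)
    return torch.stack(all_c)

x = torch.randn(1, 8, 16)
c = torch.randn(1, 8, 16)
for block_id in range(3):
    c = toy_block(c, x, block_id, after_proj=lambda t: t + 1.0)
hints, state = torch.unbind(c)[:-1], torch.unbind(c)[-1]
print(len(hints))  # 3: one hint per control block, as forward_layers expects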

View File

@@ -610,6 +610,72 @@ class ZImageDiT(nn.Module):
# all_cap_pad_mask,
# )
def patchify_controlnet(
self,
all_image: List[torch.Tensor],
patch_size: int = 2,
f_patch_size: int = 1,
cap_padding_len: int = None,
):
pH = pW = patch_size
pF = f_patch_size
device = all_image[0].device
all_image_out = []
all_image_size = []
all_image_pos_ids = []
all_image_pad_mask = []
for i, image in enumerate(all_image):
### Process Image
C, F, H, W = image.size()
all_image_size.append((F, H, W))
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
image_ori_pos_ids = self.create_coordinate_grid(
size=(F_tokens, H_tokens, W_tokens),
start=(cap_padding_len + 1, 0, 0),
device=device,
).flatten(0, 2)
image_padding_pos_ids = (
self.create_coordinate_grid(
size=(1, 1, 1),
start=(0, 0, 0),
device=device,
)
.flatten(0, 2)
.repeat(image_padding_len, 1)
)
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
all_image_pos_ids.append(image_padded_pos_ids)
# pad mask
all_image_pad_mask.append(
torch.cat(
[
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
],
dim=0,
)
)
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return (
all_image_out,
all_image_size,
all_image_pos_ids,
all_image_pad_mask,
)
def _prepare_sequence(
self,
feats: List[torch.Tensor],
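
A quick standalone shape check for patchify_controlnet's rearrange step, assuming a 1024x1024 input (so a 128x128 latent) and the 2x2 spatial patching behind the "2-1" embedder key:

import torch

C, F, H, W = 33, 1, 128, 128   # 33 = control_in_dim of ZImageControlNet
pF, pH, pW = 1, 2, 2
image = torch.randn(C, F, H, W)

Ft, Ht, Wt = F // pF, H // pH, W // pW
tokens = (
    image.view(C, Ft, pF, Ht, pH, Wt, pW)
    .permute(1, 3, 5, 2, 4, 6, 0)   # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
    .reshape(Ft * Ht * Wt, pF * pH * pW * C)
)
print(tokens.shape)  # torch.Size([4096, 132]), matching nn.Linear(1*2*2*33, dim)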

View File

@@ -16,6 +16,7 @@ from ..models.z_image_text_encoder import ZImageTextEncoder
from ..models.z_image_dit import ZImageDiT
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
from ..models.z_image_controlnet import ZImageControlNet
class ZImagePipeline(BasePipeline):
@@ -31,8 +32,9 @@ class ZImagePipeline(BasePipeline):
self.vae_encoder: FluxVAEEncoder = None
self.vae_decoder: FluxVAEDecoder = None
self.image_encoder: Siglip2ImageEncoder428M = None
self.controlnet: ZImageControlNet = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit", "controlnet")
self.units = [
ZImageUnit_ShapeChecker(),
ZImageUnit_PromptEmbedder(),
@@ -41,6 +43,7 @@ class ZImagePipeline(BasePipeline):
ZImageUnit_EditImageAutoResize(),
ZImageUnit_EditImageEmbedderVAE(),
ZImageUnit_EditImageEmbedderSiglip(),
ZImageUnit_PAIControlNet(),
]
self.model_fn = model_fn_z_image
@@ -63,6 +66,7 @@ class ZImagePipeline(BasePipeline):
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
@@ -94,6 +98,8 @@ class ZImagePipeline(BasePipeline):
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,
# ControlNet
controlnet_inputs: List[ControlNetInput] = None,
# Progress bar
progress_bar_cmd = tqdm,
):
@@ -114,6 +120,7 @@ class ZImagePipeline(BasePipeline):
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"num_inference_steps": num_inference_steps, "num_inference_steps": num_inference_steps,
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
"controlnet_inputs": controlnet_inputs,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -331,7 +338,9 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit):
if edit_image_auto_resize is None or not edit_image_auto_resize:
return {}
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
if not isinstance(edit_image, list):
edit_image = [edit_image]
edit_image = [operator(i) for i in edit_image]
return {"edit_image": edit_image}
@@ -376,8 +385,49 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
return {"image_latents": image_latents} return {"image_latents": image_latents}
class ZImageUnit_PAIControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "height", "width"),
output_params=("control_context", "control_scale"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
if controlnet_inputs is None:
return {}
if len(controlnet_inputs) != 1:
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
controlnet_input = controlnet_inputs[0]
pipe.load_models_to_device(self.onload_model_names)
control_image = controlnet_input.image
if control_image is not None:
control_image = pipe.preprocess_image(control_image)
control_latents = pipe.vae_encoder(control_image)
else:
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
inpaint_mask = controlnet_input.inpaint_mask
if inpaint_mask is not None:
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
inpaint_image = controlnet_input.inpaint_image
inpaint_image = pipe.preprocess_image(inpaint_image)
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
else:
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
inpaint_latent = pipe.vae_encoder(inpaint_image)
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
return {"control_context": control_context, "control_scale": controlnet_input.scale}
def model_fn_z_image(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
@@ -393,13 +443,14 @@ def model_fn_z_image(
if dit.siglip_embedder is None:
return model_fn_z_image_turbo(
dit,
controlnet=controlnet,
latents=latents,
timestep=timestep,
prompt_embeds=prompt_embeds,
image_embeds=image_embeds,
image_latents=image_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
**kwargs,
)
latents = [rearrange(latents, "B C H W -> C B H W")]
@@ -431,11 +482,14 @@ def model_fn_z_image(
def model_fn_z_image_turbo(
dit: ZImageDiT,
controlnet: ZImageControlNet = None,
latents=None,
timestep=None,
prompt_embeds=None,
image_embeds=None,
image_latents=None,
control_context=None,
control_scale=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -460,11 +514,17 @@ def model_fn_z_image_turbo(
# Noise refine
x = dit.all_x_embedder["2-1"](x)
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
x = rearrange(x, "L C -> 1 L C")
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1)
for layer_id, layer in enumerate(dit.noise_refiner):
x = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -474,6 +534,8 @@ def model_fn_z_image_turbo(
freqs_cis=x_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
x = x + refiner_hints[layer_id] * control_scale
# Prompt refine
cap_feats = dit.cap_embedder(cap_feats)
@@ -495,7 +557,13 @@ def model_fn_z_image_turbo(
# Unified
unified = torch.cat([x, cap_feats], dim=1)
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
if control_context is not None:
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
hints = controlnet.forward_layers(
unified, cap_feats, control_context, control_context_item_seqlens, kwargs)
for layer_id, layer in enumerate(dit.layers):
unified = gradient_checkpoint_forward(
layer,
use_gradient_checkpointing=use_gradient_checkpointing,
@@ -505,6 +573,9 @@ def model_fn_z_image_turbo(
freqs_cis=unified_freqs_cis,
adaln_input=t_noisy,
)
if control_context is not None:
if layer_id in controlnet.control_layers_mapping:
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
# Output
unified = dit.all_final_layer["2-1"](unified, t_noisy)
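
Taken together, ZImageUnit_PAIControlNet above also explains the controlnet's control_in_dim of 33: the control context concatenates 16 control-image latent channels, 1 downsampled inpaint-mask channel, and 16 inpaint-image latent channels (the 16-channel latent space is assumed from the Flux-style VAE this pipeline uses). A standalone sketch of that assembly with random stand-in tensors:

import torch
from einops import rearrange

h = w = 1024
control_latents = torch.randn(1, 16, h // 8, w // 8)  # VAE(control image)
inpaint_mask    = torch.zeros(1, 1, h // 8, w // 8)   # no-inpainting case
inpaint_latent  = torch.randn(1, 16, h // 8, w // 8)  # VAE(masked image)

control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
print(control_context.shape)  # torch.Size([1, 33, 1, 128, 128])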

View File

@@ -9,5 +9,6 @@ class ControlNetInput:
start: float = 1.0
end: float = 0.0
image: Image.Image = None
inpaint_image: Image.Image = None
inpaint_mask: Image.Image = None
processor_id: str = None

View File

@@ -0,0 +1,27 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/upscale/low_res.png"
)
controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024))
prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性他展现出时尚而自信的形象。人物拥有精心打理的短发发型两侧修剪得较短顶部保留一定长度呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜为整体造型增添了潮流感。脸上洋溢着温和友善的笑容神情放松自然给人以阳光开朗的印象。他身穿一件经典的牛仔外套这件单品永不过时展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调领口处隐约可见内搭的衣物。照片的背景是典型的城市街景可以看到模糊的建筑物、街道和行人营造出繁华都市的氛围。背景经过了恰当的虚化处理使人物主体更加突出。光线明亮而柔和可能是白天的自然光为照片带来清新通透的视觉效果。整张照片构图专业景深控制得当完美捕捉了一个现代都市年轻人充满活力和自信的瞬间展现出积极向上的生活态度。"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_tile.jpg")

View File

@@ -0,0 +1,40 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
# Control
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")
# Inpaint
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
prompt = "一只戴着墨镜的猫"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])
image.save("image_inpaint.jpg")

View File

@@ -0,0 +1,46 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
# Control
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],
num_inference_steps=30,
)
image.save("image_control.jpg")
# Inpaint
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
prompt = "一只戴着墨镜的猫"
image = pipe(
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],
num_inference_steps=30,
)
image.save("image_inpaint.jpg")

View File

@@ -0,0 +1,37 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
local_dir="./",
allow_file_pattern="data/examples/upscale/low_res.png"
)
controlnet_image = Image.open("data/examples/upscale/low_res.png").resize((1024, 1024))
prompt = "这是一张充满都市气息的户外人物肖像照片。画面中是一位年轻男性他展现出时尚而自信的形象。人物拥有精心打理的短发发型两侧修剪得较短顶部保留一定长度呈现出流行的Undercut造型。他佩戴着一副时尚的浅色墨镜或透明镜框眼镜为整体造型增添了潮流感。脸上洋溢着温和友善的笑容神情放松自然给人以阳光开朗的印象。他身穿一件经典的牛仔外套这件单品永不过时展现出休闲又有型的穿衣风格。牛仔外套的蓝色调与整体氛围十分协调领口处隐约可见内搭的衣物。照片的背景是典型的城市街景可以看到模糊的建筑物、街道和行人营造出繁华都市的氛围。背景经过了恰当的虚化处理使人物主体更加突出。光线明亮而柔和可能是白天的自然光为照片带来清新通透的视觉效果。整张照片构图专业景深控制得当完美捕捉了一个现代都市年轻人充满活力和自信的瞬间展现出积极向上的生活态度。"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_tile.jpg")

View File

@@ -0,0 +1,50 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
# Control
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")
# Inpaint
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
prompt = "一只戴着墨镜的猫"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)])
image.save("image_inpaint.jpg")

View File

@@ -0,0 +1,56 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from modelscope import dataset_snapshot_download
from PIL import Image
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
# Control
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="depth/image_1.jpg"
)
controlnet_image = Image.open("data/example_image_dataset/depth/image_1.jpg").resize((1024, 1024))
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)],
num_inference_steps=30,
)
image.save("image_control.jpg")
# Inpaint
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/example_image_dataset",
local_dir="./data/example_image_dataset",
allow_file_pattern="inpaint/*.jpg"
)
inpaint_image = Image.open("./data/example_image_dataset/inpaint/image_1.jpg").convert("RGB").resize((1024, 1024))
inpaint_mask = Image.open("./data/example_image_dataset/inpaint/mask.jpg").convert("RGB").resize((1024, 1024))
prompt = "一只戴着墨镜的猫"
image = pipe(
prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(inpaint_image=inpaint_image, inpaint_mask=inpaint_mask, scale=0.7)],
num_inference_steps=30,
)
image.save("image_inpaint.jpg")

View File

@@ -1,4 +1,5 @@
# This example is tested on 8*A100
# Text to image training
accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
@@ -12,3 +13,20 @@ accelerate launch --config_file examples/z_image/model_training/full/accelerate_
--trainable_models "dit" \ --trainable_models "dit" \
--use_gradient_checkpointing \ --use_gradient_checkpointing \
--dataset_num_workers 8 --dataset_num_workers 8
# Image(s) to image training
# accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 400 \
# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
# --learning_rate 1e-5 \
# --num_epochs 2 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/Z-Image-Omni-Base_full_edit" \
# --trainable_models "dit" \
# --use_gradient_checkpointing \
# --dataset_num_workers 8

View File

@@ -0,0 +1,15 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.controlnet." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8
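
The script reads metadata_controlnet_upscale.csv and declares image and controlnet_image as file columns via --data_file_keys. The dataset file itself is not part of this commit; a plausible layout, with column names inferred from the flags and all paths and prompts purely illustrative, would be:

import csv

rows = [
    {"image": "upscale/image_1.jpg",               # hypothetical target image
     "controlnet_image": "upscale/low_res_1.jpg",  # hypothetical control image
     "prompt": "a sharp, high-resolution photo"},
]
with open("metadata_controlnet_upscale.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["image", "controlnet_image", "prompt"])
    writer.writeheader()
    writer.writerows(rows)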

View File

@@ -0,0 +1,15 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.controlnet." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8

View File

@@ -0,0 +1,15 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.controlnet." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full" \
--trainable_models "controlnet" \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8

View File

@@ -1,3 +1,4 @@
# Text to image training
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
@@ -13,3 +14,22 @@ accelerate launch examples/z_image/model_training/train.py \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8
# Image(s) to image training
# accelerate launch examples/z_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
# --data_file_keys "image,edit_image" \
# --extra_inputs "edit_image" \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/Z-Image-Omni-Base_lora_edit" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
# --lora_rank 32 \
# --use_gradient_checkpointing \
# --dataset_num_workers 8

View File

@@ -0,0 +1,17 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_upscale.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
--lora_rank 32 \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8

View File

@@ -0,0 +1,17 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
--lora_rank 32 \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8

View File

@@ -0,0 +1,17 @@
accelerate launch examples/z_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata_controlnet_canny.csv \
--data_file_keys "image,controlnet_image" \
--max_pixels 1048576 \
--dataset_repeat 100 \
--model_id_with_origin_paths "PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1:Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors,Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
--lora_rank 32 \
--extra_inputs "controlnet_image" \
--use_gradient_checkpointing \
--dataset_num_workers 8

View File

@@ -14,8 +14,20 @@ pipe = ZImagePipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
image.save("image.jpg")
# Edit
# state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full_edit/epoch-1.safetensors", torch_dtype=torch.bfloat16)
# pipe.dit.load_state_dict(state_dict)
# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
# images = [
# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
# ]
# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images)
# image.save("image.jpg")

View File

@@ -0,0 +1,24 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_full/epoch-1.safetensors")
pipe.controlnet.load_state_dict(state_dict)
controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])
image.save("image_tile.jpg")

View File

@@ -0,0 +1,24 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_full/epoch-1.safetensors")
pipe.controlnet.load_state_dict(state_dict)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")

View File

@@ -0,0 +1,24 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_full/epoch-1.safetensors")
pipe.controlnet.load_state_dict(state_dict)
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")

View File

@@ -1,4 +1,5 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
from PIL import Image
import torch
@@ -13,7 +14,18 @@ pipe = ZImagePipeline.from_pretrained(
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
image.save("image.jpg")
# Edit
# pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora_edit/epoch-4.safetensors")
# prompt = "Change the color of the dress in Figure 1 to the color shown in Figure 2."
# images = [
# Image.open("data/example_image_dataset/edit/image1.jpg").resize((1024, 1024)),
# Image.open("data/example_image_dataset/edit/image_color.jpg").resize((1024, 1024)),
# ]
# image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4, edit_image=images)
# image.save("image.jpg")

View File

@@ -0,0 +1,23 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps_lora/epoch-4.safetensors")
controlnet_image = Image.open("data/example_image_dataset/upscale/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=1)])
image.save("image_tile.jpg")

View File

@@ -0,0 +1,23 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps_lora/epoch-4.safetensors")
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")

View File

@@ -0,0 +1,23 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig, ControlNetInput
from diffsynth import load_state_dict
from PIL import Image
import torch
pipe = ZImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo-Fun-Controlnet-Union-2.1_lora/epoch-4.safetensors")
controlnet_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1024, 1024))
prompt = "a dog"
image = pipe(prompt=prompt, seed=0, height=1024, width=1024, controlnet_inputs=[ControlNetInput(image=controlnet_image, scale=0.7)])
image.save("image_control.jpg")