mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support z-image-omni-base training
This commit is contained in:
@@ -97,6 +97,7 @@ class ModelConfig:
|
||||
self.reset_local_model_path()
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.path is None:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
|
||||
@@ -90,12 +90,10 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
super().__init__(config)
|
||||
self.processor = Siglip2ImageProcessorFast(
|
||||
**{
|
||||
"crop_size": None,
|
||||
"data_format": "channels_first",
|
||||
"default_to_square": True,
|
||||
"device": None,
|
||||
"disable_grouping": None,
|
||||
"do_center_crop": None,
|
||||
"do_convert_rgb": None,
|
||||
"do_normalize": True,
|
||||
"do_pad": None,
|
||||
@@ -120,7 +118,6 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
"resample": 2,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"return_tensors": None,
|
||||
"size": None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -626,7 +626,7 @@ class ZImageDiT(nn.Module):
|
||||
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
# RoPE
|
||||
|
||||
24
examples/z_image/model_inference/Z-Image-Omni-Base.py
Normal file
24
examples/z_image/model_inference/Z-Image-Omni-Base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_Z-Image-Omni-Base.jpg")
|
||||
|
||||
image = Image.open("image_Z-Image-Omni-Base.jpg")
|
||||
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
|
||||
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_edit_Z-Image-Omni-Base.jpg")
|
||||
@@ -0,0 +1,33 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_Z-Image-Omni-Base.jpg")
|
||||
|
||||
image = Image.open("image_Z-Image-Omni-Base.jpg")
|
||||
prompt = "Change the women's clothes to white cheongsam, keep other content unchanged"
|
||||
image = pipe(prompt=prompt, edit_image=image, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image_edit_Z-Image-Omni-Base.jpg")
|
||||
14
examples/z_image/model_training/full/Z-Image-Omni-Base.sh
Normal file
14
examples/z_image/model_training/full/Z-Image-Omni-Base.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
# This example is tested on 8*A100
|
||||
accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Z-Image-Omni-Base_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
15
examples/z_image/model_training/lora/Z-Image-Omni-Base.sh
Normal file
15
examples/z_image/model_training/lora/Z-Image-Omni-Base.sh
Normal file
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/z_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Tongyi-MAI/Z-Image-Omni-Base:transformer/*.safetensors,Tongyi-MAI/Z-Image-Omni-Base:siglip/model.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Z-Image-Omni-Base_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
@@ -0,0 +1,21 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/Z-Image-Omni-Base_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,19 @@
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Omni-Base", origin_file_pattern="siglip/model.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Omni-Base_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=40, cfg_scale=4)
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user