Z-Image: add distillation (trajectory imitation) training example

This commit is contained in:
Artiprocher
2025-12-03 11:20:49 +08:00
parent 5065c9ef6a
commit 4a80e9c179
8 changed files with 185 additions and 26 deletions

View File

@@ -13,27 +13,3 @@ accelerate launch examples/z_image/model_training/train.py \
--lora_rank 32 \
--use_gradient_checkpointing \
--dataset_num_workers 8
# Z-Image-Turbo is a distilled model.
# After training, it loses its distillation-based acceleration capability,
# leading to degraded generation quality at fewer inference steps.
# This issue can be mitigated by using a pre-trained LoRA model to assist the training process.
# accelerate launch examples/z_image/model_training/train.py \
# --dataset_base_path data/example_image_dataset \
# --dataset_metadata_path data/example_image_dataset/metadata.csv \
# --max_pixels 1048576 \
# --dataset_repeat 50 \
# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/Z-Image-Turbo_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
# --lora_rank 32 \
# --preset_lora_path "models/ostris/zimage_turbo_training_adapter/zimage_turbo_training_adapter_v1.safetensors" \
# --preset_lora_model "dit" \
# --use_gradient_checkpointing \
# --dataset_num_workers 8

View File

@@ -0,0 +1,23 @@
# Z-Image-Turbo is a distilled model.
# After training, it loses its distillation-based acceleration capability,
# leading to degraded generation quality at fewer inference steps.
# This issue can be mitigated by using a pre-trained LoRA model to assist the training process.
# https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter
# NOTE(review): download the adapter from the URL above first so that
# --preset_lora_path points at an existing local file; it is presumably
# applied to the DiT (--preset_lora_model "dit") alongside the new rank-32
# LoRA being trained — confirm the exact preset-LoRA handling in train.py.
# Trains a LoRA on the attention (to_q/to_k/to_v/to_out.0) and FFN (w1/w2/w3)
# projections of the DiT; checkpoints are saved under --output_path with the
# "pipe.dit." prefix stripped from state-dict keys.
accelerate launch examples/z_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Z-Image-Turbo_lora_differential" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
  --lora_rank 32 \
  --preset_lora_path "models/ostris/zimage_turbo_training_adapter/zimage_turbo_training_adapter_v1.safetensors" \
  --preset_lora_model "dit" \
  --use_gradient_checkpointing \
  --dataset_num_workers 8

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

# Every component of the pipeline is fetched from the same ModelScope repo.
MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"

# Assemble the Z-Image-Turbo pipeline (transformer, text encoder, VAE) in
# bfloat16 on the GPU.
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
        for pattern in (
            "transformer/*.safetensors",
            "text_encoder/*.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
        )
    ],
    tokenizer_config=ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/"),
)

# Attach the LoRA produced by the adapter-assisted ("differential") training
# run (final epoch checkpoint) to the DiT, then render one sample.
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora_differential/epoch-4.safetensors")
image = pipe(prompt="a dog", seed=42, rand_device="cuda")
image.save("image.jpg")

View File

@@ -0,0 +1,18 @@
# Distillation stage: restore the LoRA trained in the previous run
# (--lora_checkpoint) and continue training it with the experimental
# "trajectory_imitation" task (see train.py), saving a checkpoint every
# 10 steps (--save_steps 10) under --output_path.
accelerate launch examples/z_image/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata.csv \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Z-Image-Turbo_lora_distill" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
  --lora_rank 32 \
  --lora_checkpoint "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors" \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --task "trajectory_imitation" \
  --save_steps 10

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch


def _turbo_config(pattern):
    """Return a ModelConfig for one component of the Z-Image-Turbo repo."""
    return ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern=pattern)


# Build the full pipeline (transformer, text encoder, VAE + tokenizer) in
# bfloat16 on the GPU.
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        _turbo_config("transformer/*.safetensors"),
        _turbo_config("text_encoder/*.safetensors"),
        _turbo_config("vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=_turbo_config("tokenizer/"),
)

# Attach the step-20 checkpoint from the trajectory-imitation distillation run
# to the DiT, then render one sample.
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora_distill/step-20.safetensors")
image = pipe(prompt="a dog", seed=42, rand_device="cuda")
image.save("image.jpg")

View File

@@ -1,4 +1,4 @@
import torch, os, argparse, accelerate
import torch, os, argparse, accelerate, copy
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
from diffsynth.diffusion import *
@@ -50,6 +50,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
if task == "trajectory_imitation":
# This is an experimental feature.
# We may remove it in the future.
self.loss_fn = TrajectoryImitationLoss()
self.task_to_loss["trajectory_imitation"] = self.loss_fn
self.pipe_teacher = copy.deepcopy(self.pipe)
self.pipe_teacher.requires_grad_(False)
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
@@ -67,6 +74,9 @@ class ZImageTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
if self.task == "trajectory_imitation":
inputs_shared["cfg_scale"] = 2
inputs_shared["teacher"] = self.pipe_teacher
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
@@ -138,5 +148,6 @@ if __name__ == "__main__":
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
"trajectory_imitation": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)