mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
z-image
This commit is contained in:
@@ -9,7 +9,7 @@ accelerate launch examples/z_image/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Z-Image-Turbo_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
|
||||
@@ -62,7 +62,6 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"embedded_guidance": 1.0,
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
@@ -80,7 +79,7 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
return loss
|
||||
|
||||
|
||||
def qwen_image_parser():
|
||||
def z_image_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
@@ -89,7 +88,7 @@ def qwen_image_parser():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = qwen_image_parser()
|
||||
parser = z_image_parser()
|
||||
args = parser.parse_args()
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
|
||||
@@ -12,7 +12,7 @@ pipe = ZImagePipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
|
||||
Reference in New Issue
Block a user