mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
z-image
This commit is contained in:
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self, num_inference_steps):
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
|
||||
**kwargs,
|
||||
)
|
||||
if training:
|
||||
self.set_training_weight(num_inference_steps)
|
||||
self.set_training_weight()
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
@@ -3,8 +3,8 @@ import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
@@ -9,7 +9,7 @@ accelerate launch examples/z_image/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Z-Image-Turbo_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8
|
||||
|
||||
@@ -62,7 +62,6 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"embedded_guidance": 1.0,
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
@@ -80,7 +79,7 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
return loss
|
||||
|
||||
|
||||
def qwen_image_parser():
|
||||
def z_image_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
@@ -89,7 +88,7 @@ def qwen_image_parser():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = qwen_image_parser()
|
||||
parser = z_image_parser()
|
||||
args = parser.parse_args()
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
|
||||
@@ -12,7 +12,7 @@ pipe = ZImagePipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
|
||||
Reference in New Issue
Block a user