mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
z-image
This commit is contained in:
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
|
|||||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||||
|
self.num_train_timesteps = 1000
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
|
|||||||
timesteps = sigmas * num_train_timesteps
|
timesteps = sigmas * num_train_timesteps
|
||||||
return sigmas, timesteps
|
return sigmas, timesteps
|
||||||
|
|
||||||
def set_training_weight(self, num_inference_steps):
|
def set_training_weight(self):
|
||||||
|
steps = 1000
|
||||||
x = self.timesteps
|
x = self.timesteps
|
||||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||||
y_shifted = y - y.min()
|
y_shifted = y - y.min()
|
||||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||||
self.linear_timesteps_weights = bsmntw_weighing
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||||
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if training:
|
if training:
|
||||||
self.set_training_weight(num_inference_steps)
|
self.set_training_weight()
|
||||||
self.training = True
|
self.training = True
|
||||||
else:
|
else:
|
||||||
self.training = False
|
self.training = False
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ accelerate launch examples/z_image/model_training/train.py \
|
|||||||
--remove_prefix_in_ckpt "pipe.dit." \
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
--output_path "./models/train/Z-Image-Turbo_lora" \
|
--output_path "./models/train/Z-Image-Turbo_lora" \
|
||||||
--lora_base_model "dit" \
|
--lora_base_model "dit" \
|
||||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
|
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
||||||
--lora_rank 32 \
|
--lora_rank 32 \
|
||||||
--use_gradient_checkpointing \
|
--use_gradient_checkpointing \
|
||||||
--dataset_num_workers 8
|
--dataset_num_workers 8
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
|||||||
"width": data["image"].size[0],
|
"width": data["image"].size[0],
|
||||||
# Please do not modify the following parameters
|
# Please do not modify the following parameters
|
||||||
# unless you clearly know what this will cause.
|
# unless you clearly know what this will cause.
|
||||||
"embedded_guidance": 1.0,
|
|
||||||
"cfg_scale": 1,
|
"cfg_scale": 1,
|
||||||
"rand_device": self.pipe.device,
|
"rand_device": self.pipe.device,
|
||||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
@@ -80,7 +79,7 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def qwen_image_parser():
|
def z_image_parser():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser = add_general_config(parser)
|
parser = add_general_config(parser)
|
||||||
parser = add_image_size_config(parser)
|
parser = add_image_size_config(parser)
|
||||||
@@ -89,7 +88,7 @@ def qwen_image_parser():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = qwen_image_parser()
|
parser = z_image_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
accelerator = accelerate.Accelerator(
|
accelerator = accelerate.Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
)
|
)
|
||||||
pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
|
||||||
prompt = "a dog"
|
prompt = "a dog"
|
||||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||||
image.save("image.jpg")
|
image.save("image.jpg")
|
||||||
|
|||||||
Reference in New Issue
Block a user