This commit is contained in:
Artiprocher
2025-11-30 19:04:21 +08:00
parent b106458eac
commit 20cf2317e0
5 changed files with 12 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ class FlowMatchScheduler():
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2, "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image, "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux) }.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@staticmethod @staticmethod
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None): def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps timesteps = sigmas * num_train_timesteps
return sigmas, timesteps return sigmas, timesteps
def set_training_weight(self, num_inference_steps): def set_training_weight(self):
steps = 1000
x = self.timesteps x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min() y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
**kwargs, **kwargs,
) )
if training: if training:
self.set_training_weight(num_inference_steps) self.set_training_weight()
self.training = True self.training = True
else: else:
self.training = False self.training = False

View File

@@ -3,8 +3,8 @@ import torch
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs): def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps) max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps) min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device) timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

View File

@@ -9,7 +9,7 @@ accelerate launch examples/z_image/model_training/train.py \
--remove_prefix_in_ckpt "pipe.dit." \ --remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Z-Image-Turbo_lora" \ --output_path "./models/train/Z-Image-Turbo_lora" \
--lora_base_model "dit" \ --lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \ --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
--lora_rank 32 \ --lora_rank 32 \
--use_gradient_checkpointing \ --use_gradient_checkpointing \
--dataset_num_workers 8 --dataset_num_workers 8

View File

@@ -62,7 +62,6 @@ class ZImageTrainingModule(DiffusionTrainingModule):
"width": data["image"].size[0], "width": data["image"].size[0],
# Please do not modify the following parameters # Please do not modify the following parameters
# unless you clearly know what this will cause. # unless you clearly know what this will cause.
"embedded_guidance": 1.0,
"cfg_scale": 1, "cfg_scale": 1,
"rand_device": self.pipe.device, "rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing": self.use_gradient_checkpointing,
@@ -80,7 +79,7 @@ class ZImageTrainingModule(DiffusionTrainingModule):
return loss return loss
def qwen_image_parser(): def z_image_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser = add_general_config(parser) parser = add_general_config(parser)
parser = add_image_size_config(parser) parser = add_image_size_config(parser)
@@ -89,7 +88,7 @@ def qwen_image_parser():
if __name__ == "__main__": if __name__ == "__main__":
parser = qwen_image_parser() parser = z_image_parser()
args = parser.parse_args() args = parser.parse_args()
accelerator = accelerate.Accelerator( accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,

View File

@@ -12,7 +12,7 @@ pipe = ZImagePipeline.from_pretrained(
], ],
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"), tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
) )
pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors") pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
prompt = "a dog" prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda") image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg") image.save("image.jpg")