diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py
index 0216ae6..891ae0e 100644
--- a/diffsynth/diffusion/flow_match.py
+++ b/diffsynth/diffusion/flow_match.py
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
             "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
             "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
         }.get(template, FlowMatchScheduler.set_timesteps_flux)
+        self.num_train_timesteps = 1000
 
     @staticmethod
     def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
         timesteps = sigmas * num_train_timesteps
         return sigmas, timesteps
 
-    def set_training_weight(self, num_inference_steps):
+    def set_training_weight(self):
+        steps = 1000
         x = self.timesteps
-        y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
         y_shifted = y - y.min()
-        bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
         self.linear_timesteps_weights = bsmntw_weighing
 
     def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
             **kwargs,
         )
         if training:
-            self.set_training_weight(num_inference_steps)
+            self.set_training_weight()
             self.training = True
         else:
             self.training = False
diff --git a/diffsynth/diffusion/loss.py b/diffsynth/diffusion/loss.py
index 2e9d390..9a65b1e 100644
--- a/diffsynth/diffusion/loss.py
+++ b/diffsynth/diffusion/loss.py
@@ -3,8 +3,8 @@
 import torch
 
 
 def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
-    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
-    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
+    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
+    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
     timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
     timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
diff --git a/examples/z_image/model_training/lora/Z-Image-Turbo.sh b/examples/z_image/model_training/lora/Z-Image-Turbo.sh
index 0563422..a00d57e 100644
--- a/examples/z_image/model_training/lora/Z-Image-Turbo.sh
+++ b/examples/z_image/model_training/lora/Z-Image-Turbo.sh
@@ -9,7 +9,7 @@ accelerate launch examples/z_image/model_training/train.py \
   --remove_prefix_in_ckpt "pipe.dit." \
   --output_path "./models/train/Z-Image-Turbo_lora" \
   --lora_base_model "dit" \
-  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj" \
+  --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
   --lora_rank 32 \
   --use_gradient_checkpointing \
   --dataset_num_workers 8
diff --git a/examples/z_image/model_training/train.py b/examples/z_image/model_training/train.py
index 912c98f..adedf05 100644
--- a/examples/z_image/model_training/train.py
+++ b/examples/z_image/model_training/train.py
@@ -62,7 +62,6 @@ class ZImageTrainingModule(DiffusionTrainingModule):
             "width": data["image"].size[0],
             # Please do not modify the following parameters
             # unless you clearly know what this will cause.
-            "embedded_guidance": 1.0,
             "cfg_scale": 1,
             "rand_device": self.pipe.device,
             "use_gradient_checkpointing": self.use_gradient_checkpointing,
@@ -80,7 +79,7 @@ class ZImageTrainingModule(DiffusionTrainingModule):
         return loss
 
 
-def qwen_image_parser():
+def z_image_parser():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser = add_general_config(parser)
     parser = add_image_size_config(parser)
@@ -89,7 +88,7 @@ def qwen_image_parser():
 
 
 if __name__ == "__main__":
-    parser = qwen_image_parser()
+    parser = z_image_parser()
     args = parser.parse_args()
     accelerator = accelerate.Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
diff --git a/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py b/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py
index 7164741..0400a00 100644
--- a/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py
+++ b/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py
@@ -12,7 +12,7 @@ pipe = ZImagePipeline.from_pretrained(
     ],
     tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
 )
-pipe.load_lora(pipe.dit, "/models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
+pipe.load_lora(pipe.dit, "./models/train/Z-Image-Turbo_lora/epoch-4.safetensors")
 prompt = "a dog"
 image = pipe(prompt=prompt, seed=42, rand_device="cuda")
 image.save("image.jpg")