diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index fd3cb76..bb5fbc5 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -101,7 +101,7 @@ class FlowMatchScheduler(): return sigmas, timesteps @staticmethod - def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None): + def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): sigma_min = 0.0 sigma_max = 1.0 shift = 3 if shift is None else shift @@ -110,6 +110,11 @@ class FlowMatchScheduler(): sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) timesteps = sigmas * num_train_timesteps + if target_timesteps is not None: + target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device) + for timestep in target_timesteps: + timestep_id = torch.argmin((timesteps - timestep).abs()) + timesteps[timestep_id] = timestep return sigmas, timesteps def set_training_weight(self): @@ -118,6 +123,10 @@ class FlowMatchScheduler(): y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) y_shifted = y - y.min() bsmntw_weighing = y_shifted * (steps / y_shifted.sum()) + if len(self.timesteps) != 1000: + # This is an empirical formula. 
class TrajectoryImitationLoss(torch.nn.Module):
    """Loss that trains a few-step student pipeline to follow a teacher's denoising trajectory.

    A teacher pipeline (taken from ``inputs_shared["teacher"]``) is rolled out
    with many steps while its timesteps are snapped onto the student's
    timesteps. The student is then trained with two terms:

    * ``align_trajectory``: an MSE between the student's prediction and the
      velocity implied by two consecutive teacher latents.
    * ``compute_regularization``: an LPIPS distance between the decoded
      student result and the decoded teacher result.

    Args:
        num_student_steps: Number of inference steps used for the student.
        num_teacher_steps: Number of inference steps used for the teacher rollout.
        teacher_cfg_scale: CFG scale passed to the teacher rollout.
        student_cfg_scale: CFG scale passed to the student passes.

    The defaults reproduce the values that were previously hard-coded.
    """

    def __init__(self, num_student_steps=8, num_teacher_steps=50, teacher_cfg_scale=2, student_cfg_scale=1):
        super().__init__()
        self.num_student_steps = num_student_steps
        self.num_teacher_steps = num_teacher_steps
        self.teacher_cfg_scale = teacher_cfg_scale
        self.student_cfg_scale = student_cfg_scale
        # The LPIPS network is created lazily on the first forward pass.
        self.initialized = False

    def initialize(self, device):
        """Create the LPIPS (AlexNet) metric on ``device``.

        Raises:
            ImportError: If the optional ``lpips`` package is not installed.
        """
        try:
            import lpips  # TODO: remove it
        except ImportError as err:
            raise ImportError(
                "TrajectoryImitationLoss requires the optional `lpips` package (pip install lpips)."
            ) from err
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        # The metric is registered as a submodule of this loss. Freeze it so it is
        # never treated as a trainable part of the model being distilled.
        self.loss_fn.requires_grad_(False)
        self.initialized = True

    @staticmethod
    def _nearest_index(timesteps, timestep):
        """Return the index (as ``int``) of the entry of ``timesteps`` closest to ``timestep``."""
        # Compare in float32: in a low-precision dtype such as bfloat16,
        # neighbouring timesteps can round to the same value.
        reference = timesteps.to(dtype=torch.float32)
        query = timestep.to(dtype=torch.float32, device=reference.device)
        return int(torch.argmin((reference - query).abs()))

    @staticmethod
    def _rollout(pipe, inputs_shared, inputs_posi, inputs_nega, cfg_scale, trajectory=None):
        """Run ``pipe`` over its current scheduler timesteps, updating ``inputs_shared["latents"]``.

        When ``trajectory`` is a list, a clone of the latents is appended after every step.
        """
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(
                pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared
            )
            if trajectory is not None:
                trajectory.append(inputs_shared["latents"].clone())

    def fetch_trajectory(self, pipe: "BasePipeline", timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Roll out ``pipe`` and record the latents before the first and after every step.

        ``timesteps_student`` is forwarded to the scheduler as ``target_timesteps``.
        ``inputs_shared["latents"]`` is overwritten with the final latents.

        Returns:
            ``(pipe.scheduler.timesteps, trajectory)`` where ``trajectory`` has
            one more entry than there are timesteps.
        """
        trajectory = [inputs_shared["latents"].clone()]
        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        self._rollout(pipe, inputs_shared, inputs_posi, inputs_nega, cfg_scale, trajectory=trajectory)
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: "BasePipeline", timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Sum, over the student timesteps, the weighted MSE to the teacher-implied velocity.

        For every student timestep the input latents are the teacher latents at the
        nearest teacher timestep. The regression target is the difference to the
        teacher latents at the next student timestep (or the teacher's final latents
        on the last step), divided by the corresponding sigma difference.
        ``inputs_shared["latents"]`` is overwritten during the loop.
        """
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        num_steps = len(pipe.scheduler.timesteps)
        for progress_id, raw_timestep in enumerate(pipe.scheduler.timesteps):
            timestep = raw_timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            # Match on the uncast scheduler value so the lookup is not affected by torch_dtype rounding.
            inputs_shared["latents"] = trajectory_teacher[self._nearest_index(timesteps_teacher, raw_timestep)]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            sigma = pipe.scheduler.sigmas[progress_id]
            if progress_id + 1 >= num_steps:
                # Last step: integrate down to sigma 0 and aim at the teacher's final latents.
                sigma_ = 0
                latents_ = trajectory_teacher[-1]
            else:
                sigma_ = pipe.scheduler.sigmas[progress_id + 1]
                next_id = self._nearest_index(timesteps_teacher, pipe.scheduler.timesteps[progress_id + 1])
                latents_ = trajectory_teacher[next_id]

            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: "BasePipeline", trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        """Return the LPIPS distance between the decoded student result and the decoded teacher result.

        The student is rolled out from the teacher's initial latents
        (``trajectory_teacher[0]``); both final latents are decoded with
        ``pipe.vae_decoder``. Requires ``initialize`` to have been called.
        """
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        # NOTE(review): the rollout detaches noise_pred before pipe.step, so this term
        # appears not to backpropagate into the student - confirm this is intended.
        self._rollout(pipe, inputs_shared, inputs_posi, inputs_nega, cfg_scale)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: "BasePipeline", inputs_shared, inputs_posi, inputs_nega):
        """Compute the trajectory-alignment loss plus the LPIPS regularization.

        ``inputs_shared`` must contain ``"latents"`` (the starting latents) and
        ``"teacher"`` (the teacher pipeline). ``inputs_shared["latents"]`` is
        modified in the process.
        """
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            # Compute the student's timesteps first so the teacher schedule can be snapped onto them.
            pipe.scheduler.set_timesteps(self.num_student_steps)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(
                inputs_shared["teacher"], pipe.scheduler.timesteps,
                inputs_shared, inputs_posi, inputs_nega,
                self.num_teacher_steps, self.teacher_cfg_scale,
            )
            # Keep full precision: these values are only used for nearest-timestep lookups.
            timesteps_teacher = timesteps_teacher.to(dtype=torch.float32, device=pipe.device)
        loss_1 = self.align_trajectory(
            pipe, timesteps_teacher, trajectory_teacher,
            inputs_shared, inputs_posi, inputs_nega,
            self.num_student_steps, self.student_cfg_scale,
        )
        loss_2 = self.compute_regularization(
            pipe, trajectory_teacher,
            inputs_shared, inputs_posi, inputs_nega,
            self.num_student_steps, self.student_cfg_scale,
        )
        return loss_1 + loss_2
- -# accelerate launch examples/z_image/model_training/train.py \ -# --dataset_base_path data/example_image_dataset \ -# --dataset_metadata_path data/example_image_dataset/metadata.csv \ -# --max_pixels 1048576 \ -# --dataset_repeat 50 \ -# --model_id_with_origin_paths "Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \ -# --learning_rate 1e-4 \ -# --num_epochs 5 \ -# --remove_prefix_in_ckpt "pipe.dit." \ -# --output_path "./models/train/Z-Image-Turbo_lora" \ -# --lora_base_model "dit" \ -# --lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \ -# --lora_rank 32 \ -# --preset_lora_path "models/ostris/zimage_turbo_training_adapter/zimage_turbo_training_adapter_v1.safetensors" \ -# --preset_lora_model "dit" \ -# --use_gradient_checkpointing \ -# --dataset_num_workers 8 diff --git a/examples/z_image/model_training/special/differential_training/Z-Image-Turbo.sh b/examples/z_image/model_training/special/differential_training/Z-Image-Turbo.sh new file mode 100644 index 0000000..1751ec7 --- /dev/null +++ b/examples/z_image/model_training/special/differential_training/Z-Image-Turbo.sh @@ -0,0 +1,23 @@ +# Z-Image-Turbo is a distilled model. +# After training, it loses its distillation-based acceleration capability, +# leading to degraded generation quality at fewer inference steps. +# This issue can be mitigated by using a pre-trained LoRA model to assist the training process. 
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

# Validation script for the differential-training LoRA: loads Z-Image-Turbo in
# bfloat16 on CUDA, applies the LoRA checkpoint to the DiT, renders one image
# for a fixed prompt and seed, and writes it to image.jpg.

MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
WEIGHT_PATTERNS = (
    "transformer/*.safetensors",
    "text_encoder/*.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
)
LORA_PATH = "./models/train/Z-Image-Turbo_lora_differential/epoch-4.safetensors"

# One ModelConfig per weight group, all from the same model repository.
model_configs = [
    ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
    for pattern in WEIGHT_PATTERNS
]
tokenizer_config = ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/")

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=tokenizer_config,
)
pipe.load_lora(pipe.dit, LORA_PATH)

prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

# Validation script for the trajectory-imitation LoRA: loads Z-Image-Turbo in
# bfloat16 on CUDA, applies the LoRA checkpoint to the DiT, renders one image
# for a fixed prompt and seed, and writes it to image.jpg.

MODEL_ID = "Tongyi-MAI/Z-Image-Turbo"
WEIGHT_PATTERNS = (
    "transformer/*.safetensors",
    "text_encoder/*.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
)
LORA_PATH = "./models/train/Z-Image-Turbo_lora_distill/step-20.safetensors"

# One ModelConfig per weight group, all from the same model repository.
model_configs = [
    ModelConfig(model_id=MODEL_ID, origin_file_pattern=pattern)
    for pattern in WEIGHT_PATTERNS
]
tokenizer_config = ModelConfig(model_id=MODEL_ID, origin_file_pattern="tokenizer/")

pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=tokenizer_config,
)
pipe.load_lora(pipe.dit, LORA_PATH)

prompt = "a dog"
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
ZImageTrainingModule(DiffusionTrainingModule): "direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), "direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi), } + if task == "trajectory_imitation": + # This is an experimental feature. + # We may remove it in the future. + self.loss_fn = TrajectoryImitationLoss() + self.task_to_loss["trajectory_imitation"] = self.loss_fn + self.pipe_teacher = copy.deepcopy(self.pipe) + self.pipe_teacher.requires_grad_(False) def get_pipeline_inputs(self, data): inputs_posi = {"prompt": data["prompt"]} @@ -67,6 +74,9 @@ class ZImageTrainingModule(DiffusionTrainingModule): "use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, } + if self.task == "trajectory_imitation": + inputs_shared["cfg_scale"] = 2 + inputs_shared["teacher"] = self.pipe_teacher inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared) return inputs_shared, inputs_posi, inputs_nega @@ -138,5 +148,6 @@ if __name__ == "__main__": "sft:train": launch_training_task, "direct_distill": launch_training_task, "direct_distill:train": launch_training_task, + "trajectory_imitation": launch_training_task, } launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)