From 6737dbfc9f48b50312101f317211a70587d376ff Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 13 Oct 2025 16:39:57 +0800
Subject: [PATCH] support dpo

---
 diffsynth/trainers/utils.py                 | 58 +++++++++++++++++++++
 examples/qwen_image/model_training/train.py | 42 ++++++++-------
 2 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index 3262d15..ed14b19 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -475,6 +475,64 @@ class DiffusionTrainingModule(torch.nn.Module):
         if len(load_result[1]) > 0:
             print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
         setattr(pipe, lora_base_model, model)
+
+    def disable_all_lora_layers(self, model):
+        for name, module in model.named_modules():
+            if hasattr(module, 'enable_adapters'):
+                module.enable_adapters(False)
+
+    def enable_all_lora_layers(self, model):
+        for name, module in model.named_modules():
+            if hasattr(module, 'enable_adapters'):
+                module.enable_adapters(True)
+
+
+class DPOLoss:
+    def __init__(self, beta=2500):
+        self.beta = beta
+
+    def sample_timestep(self, pipe):
+        timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
+        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
+        return timestep
+
+    def training_loss_minimum(self, pipe, noise, timestep, **inputs):
+        inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
+        training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
+        noise_pred = pipe.model_fn(**inputs, timestep=timestep)
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * pipe.scheduler.training_weight(timestep)
+        return loss
+
+    def loss(self, model, data):
+        # Loss DPO: -logσ(−β(diff_policy − diff_ref))
+        # Prepare inputs
+        win_data = {key: data[key] for key in ["prompt", "image"]}
+        lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
+        inputs_win = model.forward_preprocess(win_data)
+        inputs_lose = model.forward_preprocess(lose_data)
+        inputs_win.pop('noise')
+        inputs_lose.pop('noise')
+        models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
+        # sample timestep and Gaussian noise (randn_like: add_noise/training_target expect standard normal noise)
+        timestep = self.sample_timestep(model.pipe)
+        noise = torch.randn_like(inputs_win["latents"])
+        # compute diff_policy = loss_win - loss_lose
+        loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
+        loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
+        diff_policy = loss_win - loss_lose
+        # compute diff_ref
+        # TODO: may support full model training
+        model.disable_all_lora_layers(model.pipe.dit)
+        # with LoRA disabled, the forward pass uses the original (reference) model weights
+        with torch.no_grad():
+            loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
+            loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
+            diff_ref = loss_win_ref - loss_lose_ref
+        model.enable_all_lora_layers(model.pipe.dit)
+        # compute loss
+        loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
+        return loss
 
 
 class ModelLogger:
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index f39c11c..0035d45 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -2,7 +2,7 @@ import torch, os, json
 from diffsynth import load_state_dict
 from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 from diffsynth.pipelines.flux_image_new import ControlNetInput
-from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
 from diffsynth.trainers.unified_dataset import UnifiedDataset
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -84,24 +84,29 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
 
 
     def forward(self, data, inputs=None, return_inputs=False):
-        # Inputs
-        if inputs is None:
-            inputs = self.forward_preprocess(data)
+        # DPO (DPO requires a special training loss)
+        if self.task == "dpo":
+            loss = DPOLoss().loss(self, data)
+            return loss
         else:
-            inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
-        if return_inputs: return inputs
-
-        # Loss
-        if self.task == "sft":
-            models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
-            loss = self.pipe.training_loss(**models, **inputs)
-        elif self.task == "data_process":
-            loss = inputs
-        elif self.task == "direct_distill":
-            loss = self.pipe.direct_distill_loss(**inputs)
-        else:
-            raise NotImplementedError(f"Unsupported task: {self.task}.")
-        return loss
+            # Inputs
+            if inputs is None:
+                inputs = self.forward_preprocess(data)
+            else:
+                inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
+            if return_inputs: return inputs
+
+            # Loss
+            if self.task == "sft":
+                models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
+                loss = self.pipe.training_loss(**models, **inputs)
+            elif self.task == "data_process":
+                loss = inputs
+            elif self.task == "direct_distill":
+                loss = self.pipe.direct_distill_loss(**inputs)
+            else:
+                raise NotImplementedError(f"Unsupported task: {self.task}.")
+            return loss
 
 
 
@@ -143,5 +148,6 @@ if __name__ == "__main__":
         "sft": launch_training_task,
         "data_process": launch_data_process_task,
         "direct_distill": launch_training_task,
+        "dpo": launch_training_task,
     }
     launcher_map[args.task](dataset, model, model_logger, args=args)