From bf7b339efbf043f31e41d9aa4d6316a680e85cfa Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 22 Sep 2025 10:14:17 +0800 Subject: [PATCH] support dpo training --- diffsynth/pipelines/qwen_image.py | 2 +- diffsynth/trainers/utils.py | 14 +- diffsynth/utils/__init__.py | 13 ++ examples/flux/model_training/train.py | 2 +- examples/qwen_image/model_training/train.py | 2 +- .../qwen_image/model_training/train_dpo.py | 184 ++++++++++++++++++ examples/wanvideo/model_training/train.py | 2 +- 7 files changed, 213 insertions(+), 6 deletions(-) create mode 100644 examples/qwen_image/model_training/train_dpo.py diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 83ff290..c6aada9 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -525,7 +525,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit): return split_result def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict: - if pipe.text_encoder is not None: + if pipe.text_encoder is not None and prompt is not None: prompt = [prompt] # If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit if edit_image is None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 3262d15..025d690 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module): param.data = param.to(upcast_dtype) return model + def disable_all_lora_layers(self, model): + for name, module in model.named_modules(): + if hasattr(module, 'enable_adapters'): + module.enable_adapters(False) + + def enable_all_lora_layers(self, model): + for name, module in model.named_modules(): + if hasattr(module, 'enable_adapters'): + module.enable_adapters(True) def mapping_lora_state_dict(self, state_dict): new_state_dict = {} @@ -554,9 +563,9 @@ def launch_training_task( with accelerator.accumulate(model): optimizer.zero_grad() if dataset.load_from_cache: - loss = model({}, inputs=data) + loss = model({}, inputs=data, accelerator=accelerator) else: - loss = model(data) + loss = model(data, accelerator=accelerator) accelerator.backward(loss) optimizer.step() model_logger.on_step_end(accelerator, model, save_steps) @@ -690,4 +699,5 @@ def qwen_image_parser(): parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.") parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.") + parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.") return parser diff --git a/diffsynth/utils/__init__.py b/diffsynth/utils/__init__.py index ec3c727..c74e9e4 100644 --- a/diffsynth/utils/__init__.py +++ b/diffsynth/utils/__init__.py @@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module): latents_next = scheduler.step(noise_pred, timestep, latents) return latents_next + def sample_timestep(self): + timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) + return timestep + + def training_loss_minimum(self, noise, timestep, **inputs): + inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep) + training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep) + noise_pred = self.model_fn(**inputs, timestep=timestep) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.scheduler.training_weight(timestep) + return loss + @dataclass diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 4a82228..de12a62 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -75,7 +75,7 @@ class FluxTrainingModule(DiffusionTrainingModule): return {**inputs_shared, **inputs_posi} - def forward(self, data, inputs=None): + def forward(self, data, inputs=None, **kwargs): if inputs is None: inputs = self.forward_preprocess(data) models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs) diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index f39c11c..3918e8e 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -83,7 +83,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): return {**inputs_shared, **inputs_posi} - def forward(self, data, inputs=None, return_inputs=False): + def forward(self, data, inputs=None, return_inputs=False, **kwargs): # Inputs if inputs is None: inputs = self.forward_preprocess(data) diff --git a/examples/qwen_image/model_training/train_dpo.py b/examples/qwen_image/model_training/train_dpo.py new file mode 100644 index 0000000..a49b32e --- /dev/null +++ b/examples/qwen_image/model_training/train_dpo.py @@ -0,0 +1,184 @@ +import torch, os +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.pipelines.flux_image_new import ControlNetInput +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task +from diffsynth.trainers.unified_dataset import UnifiedDataset +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, processor_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + enable_fp8_training=False, + task="sft", + beta_dpo=1000., + ): + super().__init__() + # Load models + model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) + tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) + self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + + # Training mode + self.switch_pipe_to_training_mode( + self.pipe, trainable_models, + lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, + enable_fp8_training=enable_fp8_training, + ) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.task = task + self.lora_base_model = lora_base_model + self.beta_dpo = beta_dpo + + def forward_preprocess(self, data): + # CFG-sensitive parameters + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + + # CFG-unsensitive parameters + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "edit_image_auto_resize": True, + } + + # Extra inputs + controlnet_input, blockwise_controlnet_input = {}, {} + for extra_input in self.extra_inputs: + if extra_input.startswith("blockwise_controlnet_"): + blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] + elif extra_input.startswith("controlnet_"): + controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] + else: + inputs_shared[extra_input] = data[extra_input] + if len(controlnet_input) > 0: + inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)] + if len(blockwise_controlnet_input) > 0: + inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)] + + # Pipeline units will automatically process the input parameters. + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + def forward_dpo(self, data, accelerator=None): + # Loss DPO: -logσ(−β(diff_policy − diff_ref)) + # Prepare inputs + win_data = {key: data[key] for key in ["prompt", "image"]} + lose_data = {"prompt": None, "image": data["lose_image"]} + inputs_win = self.forward_preprocess(win_data) + inputs_lose = self.forward_preprocess(lose_data) + inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]}) + inputs_win.pop('noise') + inputs_lose.pop('noise') + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + # sample timestep and noise + timestep = self.pipe.sample_timestep() + noise = torch.rand_like(inputs_win["latents"]) + # compute diff_policy = loss_win - loss_lose + loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) + loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) + diff_policy = loss_win - loss_lose + # compute diff_ref + if self.lora_base_model is not None: + self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) + # load the original model weights + with torch.no_grad(): + loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) + loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) + diff_ref = loss_win_ref - loss_lose_ref + self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) + else: + # TODO: may support full model training + raise NotImplementedError("DPO with full model training is not supported yet.") + # compute loss + loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean() + return loss + + def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs): + if self.task == "dpo": + return self.forward_dpo(data, accelerator=accelerator) + # Inputs + if inputs is None: + inputs = self.forward_preprocess(data) + else: + inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) + if return_inputs: return inputs + + # Loss + if self.task == "sft": + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + loss = self.pipe.training_loss(**models, **inputs) + elif self.task == "data_process": + loss = inputs + elif self.task == "direct_distill": + loss = self.pipe.direct_distill_loss(**inputs) + else: + raise NotImplementedError(f"Unsupported task: {self.task}.") + return loss + + + +if __name__ == "__main__": + parser = qwen_image_parser() + args = parser.parse_args() + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + processor_path=args.processor_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + enable_fp8_training=args.enable_fp8_training, + task=args.task, + beta_dpo=args.beta_dpo, + ) + model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) + launcher_map = { + "sft": launch_training_task, + "data_process": launch_data_process_task, + "direct_distill": launch_training_task, + "dpo": launch_training_task, + } + launcher_map[args.task](dataset, model, model_logger, args=args) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 37494e7..c3096c1 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -82,7 +82,7 @@ class WanTrainingModule(DiffusionTrainingModule): return {**inputs_shared, **inputs_posi} - def forward(self, data, inputs=None): + def forward(self, data, inputs=None, **kwargs): if inputs is None: inputs = self.forward_preprocess(data) models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs)