From d96709fb6abb5ba1be614c4e28d97c7d20ded7d4 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 22 Sep 2025 17:45:42 +0800 Subject: [PATCH] update --- .../model_training/lora/Qwen-Image-DPO.sh | 25 +++ examples/qwen_image/model_training/train.py | 51 ++++- .../qwen_image/model_training/train_dpo.py | 184 ------------------ .../validate_lora/Qwen-Image-DPO.py | 19 ++ 4 files changed, 88 insertions(+), 191 deletions(-) create mode 100644 examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh delete mode 100644 examples/qwen_image/model_training/train_dpo.py create mode 100644 examples/qwen_image/model_training/validate_lora/Qwen-Image-DPO.py diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh b/examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh new file mode 100644 index 0000000..15fe434 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh @@ -0,0 +1,25 @@ +# dataset format: +# { +# "image": "path/to/win_image.png", # win image +# "lose_image": "path/to/lose_image.png", # lose image +# "prompt": "a photo of ...", +# } +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/dpo.jsonl \ + --data_file_keys "image,lose_image" \ + --max_pixels 1048576 \ + --dataset_repeat 400 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_DPO_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --task dpo \ + --beta_dpo 2500 \ + --find_unused_parameters \ No newline at end of file diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 3918e8e..a49b32e 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -1,5 +1,4 @@ -import torch, os, json -from diffsynth import load_state_dict +import torch, os from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.flux_image_new import ControlNetInput from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task @@ -7,7 +6,6 @@ from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" - class QwenImageTrainingModule(DiffusionTrainingModule): def __init__( self, @@ -20,6 +18,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): extra_inputs=None, enable_fp8_training=False, task="sft", + beta_dpo=1000., ): super().__init__() # Load models @@ -40,8 +39,9 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.task = task + self.lora_base_model = lora_base_model + self.beta_dpo = beta_dpo - def forward_preprocess(self, data): # CFG-sensitive parameters inputs_posi = {"prompt": data["prompt"]} @@ -81,9 +81,44 @@ class QwenImageTrainingModule(DiffusionTrainingModule): for unit in self.pipe.units: inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) return {**inputs_shared, **inputs_posi} - - - def forward(self, data, inputs=None, return_inputs=False, **kwargs): + + def forward_dpo(self, data, accelerator=None): + # Loss DPO: -logσ(−β(diff_policy − diff_ref)) + # Prepare inputs + win_data = {key: data[key] for key in ["prompt", "image"]} + lose_data = {"prompt": None, "image": data["lose_image"]} + inputs_win = self.forward_preprocess(win_data) + inputs_lose = self.forward_preprocess(lose_data) + inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]}) + inputs_win.pop('noise') + inputs_lose.pop('noise') + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + # sample timestep and noise + timestep = self.pipe.sample_timestep() + noise = torch.rand_like(inputs_win["latents"]) + # compute diff_policy = loss_win - loss_lose + loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) + loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) + diff_policy = loss_win - loss_lose + # compute diff_ref + if self.lora_base_model is not None: + self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) + # load the original model weights + with torch.no_grad(): + loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) + loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) + diff_ref = loss_win_ref - loss_lose_ref + self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) + else: + # TODO: may support full model training + raise NotImplementedError("DPO with full model training is not supported yet.") + # compute loss + loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean() + return loss + + def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs): + if self.task == "dpo": + return self.forward_dpo(data, accelerator=accelerator) # Inputs if inputs is None: inputs = self.forward_preprocess(data) @@ -137,11 +172,13 @@ if __name__ == "__main__": extra_inputs=args.extra_inputs, enable_fp8_training=args.enable_fp8_training, task=args.task, + beta_dpo=args.beta_dpo, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) launcher_map = { "sft": launch_training_task, "data_process": launch_data_process_task, "direct_distill": launch_training_task, + "dpo": launch_training_task, } launcher_map[args.task](dataset, model, model_logger, args=args) diff --git a/examples/qwen_image/model_training/train_dpo.py b/examples/qwen_image/model_training/train_dpo.py deleted file mode 100644 index a49b32e..0000000 --- a/examples/qwen_image/model_training/train_dpo.py +++ /dev/null @@ -1,184 +0,0 @@ -import torch, os -from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig -from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task -from diffsynth.trainers.unified_dataset import UnifiedDataset -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -class QwenImageTrainingModule(DiffusionTrainingModule): - def __init__( - self, - model_paths=None, model_id_with_origin_paths=None, - tokenizer_path=None, processor_path=None, - trainable_models=None, - lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, - use_gradient_checkpointing=True, - use_gradient_checkpointing_offload=False, - extra_inputs=None, - enable_fp8_training=False, - task="sft", - beta_dpo=1000., - ): - super().__init__() - # Load models - model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training) - tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) - processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) - self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) - - # Training mode - self.switch_pipe_to_training_mode( - self.pipe, trainable_models, - lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, - enable_fp8_training=enable_fp8_training, - ) - - # Store other configs - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - self.task = task - self.lora_base_model = lora_base_model - self.beta_dpo = beta_dpo - - def forward_preprocess(self, data): - # CFG-sensitive parameters - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {"negative_prompt": ""} - - # CFG-unsensitive parameters - inputs_shared = { - # Assume you are using this pipeline for inference, - # please fill in the input parameters. - "input_image": data["image"], - "height": data["image"].size[1], - "width": data["image"].size[0], - # Please do not modify the following parameters - # unless you clearly know what this will cause. - "cfg_scale": 1, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": self.use_gradient_checkpointing, - "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, - "edit_image_auto_resize": True, - } - - # Extra inputs - controlnet_input, blockwise_controlnet_input = {}, {} - for extra_input in self.extra_inputs: - if extra_input.startswith("blockwise_controlnet_"): - blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] - elif extra_input.startswith("controlnet_"): - controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] - else: - inputs_shared[extra_input] = data[extra_input] - if len(controlnet_input) > 0: - inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)] - if len(blockwise_controlnet_input) > 0: - inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)] - - # Pipeline units will automatically process the input parameters. - for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - def forward_dpo(self, data, accelerator=None): - # Loss DPO: -logσ(−β(diff_policy − diff_ref)) - # Prepare inputs - win_data = {key: data[key] for key in ["prompt", "image"]} - lose_data = {"prompt": None, "image": data["lose_image"]} - inputs_win = self.forward_preprocess(win_data) - inputs_lose = self.forward_preprocess(lose_data) - inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]}) - inputs_win.pop('noise') - inputs_lose.pop('noise') - models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} - # sample timestep and noise - timestep = self.pipe.sample_timestep() - noise = torch.rand_like(inputs_win["latents"]) - # compute diff_policy = loss_win - loss_lose - loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) - loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) - diff_policy = loss_win - loss_lose - # compute diff_ref - if self.lora_base_model is not None: - self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) - # load the original model weights - with torch.no_grad(): - loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win) - loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose) - diff_ref = loss_win_ref - loss_lose_ref - self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit) - else: - # TODO: may support full model training - raise NotImplementedError("DPO with full model training is not supported yet.") - # compute loss - loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean() - return loss - - def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs): - if self.task == "dpo": - return self.forward_dpo(data, accelerator=accelerator) - # Inputs - if inputs is None: - inputs = self.forward_preprocess(data) - else: - inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype) - if return_inputs: return inputs - - # Loss - if self.task == "sft": - models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} - loss = self.pipe.training_loss(**models, **inputs) - elif self.task == "data_process": - loss = inputs - elif self.task == "direct_distill": - loss = self.pipe.direct_distill_loss(**inputs) - else: - raise NotImplementedError(f"Unsupported task: {self.task}.") - return loss - - - -if __name__ == "__main__": - parser = qwen_image_parser() - args = parser.parse_args() - dataset = UnifiedDataset( - base_path=args.dataset_base_path, - metadata_path=args.dataset_metadata_path, - repeat=args.dataset_repeat, - data_file_keys=args.data_file_keys.split(","), - main_data_operator=UnifiedDataset.default_image_operator( - base_path=args.dataset_base_path, - max_pixels=args.max_pixels, - height=args.height, - width=args.width, - height_division_factor=16, - width_division_factor=16, - ) - ) - model = QwenImageTrainingModule( - model_paths=args.model_paths, - model_id_with_origin_paths=args.model_id_with_origin_paths, - tokenizer_path=args.tokenizer_path, - processor_path=args.processor_path, - trainable_models=args.trainable_models, - lora_base_model=args.lora_base_model, - lora_target_modules=args.lora_target_modules, - lora_rank=args.lora_rank, - lora_checkpoint=args.lora_checkpoint, - use_gradient_checkpointing=args.use_gradient_checkpointing, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - extra_inputs=args.extra_inputs, - enable_fp8_training=args.enable_fp8_training, - task=args.task, - beta_dpo=args.beta_dpo, - ) - model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) - launcher_map = { - "sft": launch_training_task, - "data_process": launch_data_process_task, - "direct_distill": launch_training_task, - "dpo": launch_training_task, - } - launcher_map[args.task](dataset, model, model_logger, args=args) diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-DPO.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-DPO.py new file mode 100644 index 0000000..51501c8 --- /dev/null +++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-DPO.py @@ -0,0 +1,19 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +pipe.load_lora(pipe.dit, "models/train/Qwen-Image_DPO_lora/epoch-4.safetensors") +prompt = "黑板上写着“群起效尤,心灵手巧”,字的颜色分别是 “群”: 橙色、“起”: 黑色、“效”: 蓝色、“尤”: 绿色、“心”: 紫色、“灵”: 粉色、“手”: 红色、“巧”: 白色" +for seed in range(0, 5): + image = pipe(prompt, seed=seed) + image.save(f"image_dpo_{seed}.jpg")