support dpo training

This commit is contained in:
mi804
2025-09-22 10:14:17 +08:00
parent b0abdaffb4
commit bf7b339efb
7 changed files with 213 additions and 6 deletions

View File

@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
param.data = param.to(upcast_dtype)
return model
def disable_all_lora_layers(self, model):
for name, module in model.named_modules():
if hasattr(module, 'enable_adapters'):
module.enable_adapters(False)
def enable_all_lora_layers(self, model):
for name, module in model.named_modules():
if hasattr(module, 'enable_adapters'):
module.enable_adapters(True)
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
@@ -554,9 +563,9 @@ def launch_training_task(
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
loss = model({}, inputs=data, accelerator=accelerator)
else:
loss = model(data)
loss = model(data, accelerator=accelerator)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps)
@@ -690,4 +699,5 @@ def qwen_image_parser():
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
return parser