mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
support dpo training
This commit is contained in:
@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
param.data = param.to(upcast_dtype)
|
||||
return model
|
||||
|
||||
def disable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(False)
|
||||
|
||||
def enable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(True)
|
||||
|
||||
def mapping_lora_state_dict(self, state_dict):
|
||||
new_state_dict = {}
|
||||
@@ -554,9 +563,9 @@ def launch_training_task(
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data)
|
||||
loss = model({}, inputs=data, accelerator=accelerator)
|
||||
else:
|
||||
loss = model(data)
|
||||
loss = model(data, accelerator=accelerator)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
@@ -690,4 +699,5 @@ def qwen_image_parser():
|
||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
||||
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user