Support DPO (Direct Preference Optimization) training

This commit is contained in:
mi804
2025-09-22 10:14:17 +08:00
parent b0abdaffb4
commit bf7b339efb
7 changed files with 213 additions and 6 deletions

View File

@@ -525,7 +525,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
return split_result
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
if pipe.text_encoder is not None:
if pipe.text_encoder is not None and prompt is not None:
prompt = [prompt]
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
if edit_image is None:

View File

@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
param.data = param.to(upcast_dtype)
return model
def disable_all_lora_layers(self, model):
    """Turn off every LoRA adapter inside ``model``.

    Walks all submodules and calls ``enable_adapters(False)`` on any
    module exposing that hook (i.e. LoRA/adapter-wrapped layers).
    Used by DPO training to get a reference (adapter-free) forward pass.

    Args:
        model: a ``torch.nn.Module`` whose submodules may carry adapters.
    """
    # The module names are never used, so iterate modules() directly
    # instead of unpacking (and discarding) names from named_modules().
    for module in model.modules():
        if hasattr(module, 'enable_adapters'):
            module.enable_adapters(False)
def enable_all_lora_layers(self, model):
    """Turn back on every LoRA adapter inside ``model``.

    Counterpart of ``disable_all_lora_layers``: walks all submodules and
    calls ``enable_adapters(True)`` on any module exposing that hook.

    Args:
        model: a ``torch.nn.Module`` whose submodules may carry adapters.
    """
    # The module names are never used, so iterate modules() directly
    # instead of unpacking (and discarding) names from named_modules().
    for module in model.modules():
        if hasattr(module, 'enable_adapters'):
            module.enable_adapters(True)
def mapping_lora_state_dict(self, state_dict):
new_state_dict = {}
@@ -554,9 +563,9 @@ def launch_training_task(
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
loss = model({}, inputs=data, accelerator=accelerator)
else:
loss = model(data)
loss = model(data, accelerator=accelerator)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps)
@@ -690,4 +699,5 @@ def qwen_image_parser():
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
return parser

View File

@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
def sample_timestep(self):
    """Draw one training timestep uniformly at random from the scheduler.

    Returns a 1-element tensor taken from ``self.scheduler.timesteps``,
    cast to the pipeline's dtype and moved to its device.
    """
    # Uniform index into the scheduler's discrete timestep table.
    index = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
    return self.scheduler.timesteps[index].to(dtype=self.torch_dtype, device=self.device)
def training_loss_minimum(self, noise, timestep, **inputs):
    """Compute the weighted denoising MSE loss for a single training step.

    Args:
        noise: noise tensor added to the clean latents.
        timestep: sampled diffusion timestep.
        **inputs: model-function kwargs; must contain ``input_latents``
            (the clean latents). ``latents`` is filled in here with the
            noised version before calling ``self.model_fn``.

    Returns:
        Scalar loss: MSE against the scheduler's training target, scaled
        by the scheduler's per-timestep weight.
    """
    clean_latents = inputs["input_latents"]
    # Corrupt the clean latents with the sampled noise at this timestep;
    # note **inputs is a local copy, so the caller's dict is untouched.
    inputs["latents"] = self.scheduler.add_noise(clean_latents, noise, timestep)
    target = self.scheduler.training_target(clean_latents, noise, timestep)
    prediction = self.model_fn(**inputs, timestep=timestep)
    mse = torch.nn.functional.mse_loss(prediction.float(), target.float())
    # Scheduler-defined loss weighting for this timestep.
    return mse * self.scheduler.training_weight(timestep)
@dataclass