support dpo

This commit is contained in:
Artiprocher
2025-10-13 16:39:57 +08:00
parent 0a1c172a00
commit 6737dbfc9f
2 changed files with 82 additions and 18 deletions

View File

@@ -475,6 +475,64 @@ class DiffusionTrainingModule(torch.nn.Module):
if len(load_result[1]) > 0:
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
setattr(pipe, lora_base_model, model)
def disable_all_lora_layers(self, model):
    """Switch adapters off on every submodule of ``model`` that supports it.

    Iterates ``model.named_modules()`` and calls ``enable_adapters(False)``
    on each submodule exposing an ``enable_adapters`` attribute; all other
    submodules are left untouched.
    NOTE(review): presumably these are LoRA/PEFT tuner layers - confirm.
    """
    for _, submodule in model.named_modules():
        if not hasattr(submodule, "enable_adapters"):
            continue
        submodule.enable_adapters(False)
def enable_all_lora_layers(self, model):
    """Switch adapters back on for every submodule of ``model`` that supports it.

    Iterates ``model.named_modules()`` and calls ``enable_adapters(True)``
    on each submodule exposing an ``enable_adapters`` attribute; all other
    submodules are left untouched.
    NOTE(review): presumably these are LoRA/PEFT tuner layers - confirm.
    """
    for _, submodule in model.named_modules():
        if not hasattr(submodule, "enable_adapters"):
            continue
        submodule.enable_adapters(True)
class DPOLoss:
    """Direct Preference Optimization loss for a LoRA-tuned diffusion pipeline.

    The policy is the model with its adapters enabled; the reference is the
    same model with adapters disabled. For a preferred ("win") and a rejected
    ("lose") image sharing one prompt, the loss is

        -log(sigmoid(-beta * (diff_policy - diff_ref)))

    where each ``diff`` is ``loss_win - loss_lose`` of the denoising loss.
    """

    def __init__(self, beta=2500):
        """Store ``beta``, the scale applied to the policy/reference gap."""
        self.beta = beta

    def sample_timestep(self, pipe):
        """Pick one random training timestep from ``pipe.scheduler``.

        Returns:
            A one-element tensor taken from ``pipe.scheduler.timesteps``,
            cast to ``pipe.torch_dtype`` and moved to ``pipe.device``.
        """
        timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
        return timestep

    def training_loss_minimum(self, pipe, noise, timestep, **inputs):
        """Compute the weighted denoising MSE for a fixed noise and timestep.

        Args:
            pipe: Pipeline providing ``scheduler`` and ``model_fn``.
            noise: Noise tensor mixed into ``inputs["input_latents"]``.
            timestep: Timestep forwarded to the scheduler and the model.
            **inputs: Keyword arguments for ``pipe.model_fn``; must contain
                ``"input_latents"``. ``"latents"`` is overwritten with the
                noised latents.

        Returns:
            The mean-squared error between the model prediction and the
            scheduler's training target, scaled by
            ``pipe.scheduler.training_weight(timestep)``.
        """
        inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
        training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
        noise_pred = pipe.model_fn(**inputs, timestep=timestep)
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * pipe.scheduler.training_weight(timestep)
        return loss

    def loss(self, model, data):
        """Return the DPO loss for one preference pair.

        Args:
            model: Training module exposing ``pipe``, ``forward_preprocess``,
                ``disable_all_lora_layers`` and ``enable_all_lora_layers``.
            data: Mapping with keys ``"prompt"``, ``"image"`` (preferred
                sample) and ``"lose_image"`` (rejected sample).

        Returns:
            Scalar tensor ``-logsigmoid(beta * (diff_ref - diff_policy))``.
        """
        # DPO loss: -log(sigmoid(-beta * (diff_policy - diff_ref)))
        # Prepare inputs
        win_data = {key: data[key] for key in ["prompt", "image"]}
        lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
        inputs_win = model.forward_preprocess(win_data)
        inputs_lose = model.forward_preprocess(lose_data)
        # Drop the per-sample noise so both samples share the noise drawn below.
        inputs_win.pop('noise')
        inputs_lose.pop('noise')
        models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
        # Sample one timestep and one noise tensor, shared by all four passes.
        timestep = self.sample_timestep(model.pipe)
        # Gaussian noise (randn), not uniform (rand): the scheduler's
        # add_noise / training_target are assumed to expect standard-normal
        # noise, as is usual for diffusion training.
        noise = torch.randn_like(inputs_win["latents"])
        # compute diff_policy = loss_win - loss_lose
        loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
        loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
        diff_policy = loss_win - loss_lose
        # compute diff_ref
        # TODO: may support full model training
        # With adapters disabled the same network acts as the reference model.
        model.disable_all_lora_layers(model.pipe.dit)
        try:
            with torch.no_grad():
                loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
                loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
                diff_ref = loss_win_ref - loss_lose_ref
        finally:
            # Always restore the adapters, even if the reference pass fails,
            # so the model is never left with adapters switched off.
            model.enable_all_lora_layers(model.pipe.dit)
        # compute loss
        loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
        return loss
class ModelLogger: