support dpo training

This commit is contained in:
mi804
2025-09-22 10:14:17 +08:00
parent b0abdaffb4
commit bf7b339efb
7 changed files with 213 additions and 6 deletions

View File

@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
latents_next = scheduler.step(noise_pred, timestep, latents)
return latents_next
def sample_timestep(self):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
return timestep
def training_loss_minimum(self, noise, timestep, **inputs):
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
noise_pred = self.model_fn(**inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.scheduler.training_weight(timestep)
return loss
@dataclass