mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support dpo training
This commit is contained in:
@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
|
||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||
return latents_next
|
||||
|
||||
def sample_timestep(self):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
return timestep
|
||||
|
||||
def training_loss_minimum(self, noise, timestep, **inputs):
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user