fix: prevent division by zero in trajectory imitation loss at last step

This commit is contained in:
Mr-Neutr0n
2026-02-11 19:51:25 +05:30
parent 1b47e1dc22
commit 0e6976a0ae

View File

@@ -91,7 +91,7 @@ class TrajectoryImitationLoss(torch.nn.Module):
 progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
 latents_ = trajectory_teacher[progress_id_teacher]
-target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
+target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma).clamp(min=1e-6)
 loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
 return loss