Merge pull request #1293 from Mr-Neutr0n/fix/trajectory-loss-div-by-zero

fix: prevent division by zero in TrajectoryImitationLoss at final denoising step
This commit is contained in:
Zhongjie Duan
2026-03-02 10:21:39 +08:00
committed by GitHub

View File

@@ -121,7 +121,9 @@ class TrajectoryImitationLoss(torch.nn.Module):
progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs()) progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
latents_ = trajectory_teacher[progress_id_teacher] latents_ = trajectory_teacher[progress_id_teacher]
target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma) denom = sigma_ - sigma
denom = torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6)
target = (latents_ - inputs_shared["latents"]) / denom
loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep) loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
return loss return loss