The previous .clamp(min=1e-6) on (sigma_ - sigma) replaces a negative
denominator with +1e-6, flipping its sign; a negative denominator is
the typical case, since sigmas decrease monotonically. The flipped
sign inverts the target and causes training divergence.
Use torch.sign(denom) * torch.clamp(denom.abs(), min=1e-6) instead:
it floors the denominator's magnitude at 1e-6, preventing division by
zero, while preserving the correct sign.