Merge pull request #1025 from krahets/patch-1

Fix sinusoidal_embedding calculation for bf16 precision.
This commit is contained in:
Zhongjie Duan
2025-11-04 15:08:11 +08:00
committed by GitHub

View File

@@ -362,7 +362,7 @@ class WanModel(torch.nn.Module):
         **kwargs,
     ):
         t = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, timestep))
+            sinusoidal_embedding_1d(self.freq_dim, timestep).to(x.dtype))
         t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
         context = self.text_embedding(context)