This commit is contained in:
Artiprocher
2025-11-30 20:03:14 +08:00
parent 20cf2317e0
commit 9048d2e9d4
7 changed files with 91 additions and 3 deletions

View File

@@ -51,7 +51,7 @@ class TimestepEmbedder(nn.Module):
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
t_emb = self.mlp(t_freq.to(torch.bfloat16))
return t_emb