support stepvideo quantized

This commit is contained in:
Artiprocher
2025-02-17 19:43:47 +08:00
parent 3681adc5ac
commit f191353cf4
6 changed files with 63 additions and 5 deletions

View File

@@ -238,7 +238,7 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timestep, resolution=None, nframe=None, fps=None):
hidden_dtype = next(self.timestep_embedder.parameters()).dtype
hidden_dtype = timestep.dtype
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)