support stepvideo quantized

This commit is contained in:
Artiprocher
2025-02-17 19:43:47 +08:00
parent 3681adc5ac
commit f191353cf4
6 changed files with 63 additions and 5 deletions

View File

@@ -181,7 +181,7 @@ class StepVideoPipeline(BasePipeline):
# Denoise
self.load_models_to_device(["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
# Inference