refactor scheduler

This commit is contained in:
Artiprocher
2025-11-30 15:22:39 +08:00
parent 1a6fd69e6b
commit 675ae5e91f
6 changed files with 150 additions and 127 deletions

View File

@@ -35,7 +35,7 @@ class WanVideoPipeline(BasePipeline):
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
)
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
self.scheduler = FlowMatchScheduler("Wan")
self.tokenizer: HuggingfaceTokenizer = None
self.audio_processor: Wav2Vec2Processor = None
self.text_encoder: WanTextEncoder = None
@@ -283,7 +283,7 @@ class WanVideoPipeline(BasePipeline):
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["dit"] = self.dit2
models["vace"] = self.vace2