wan direct distill

This commit is contained in:
Artiprocher
2025-11-19 15:46:37 +08:00
parent 453ca89046
commit 6ad8d73717
4 changed files with 44 additions and 5 deletions

View File

@@ -23,6 +23,7 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)