update wan input params

This commit is contained in:
Artiprocher
2025-02-28 19:43:18 +08:00
parent 6a92b08244
commit 3a75026176
3 changed files with 14 additions and 3 deletions

View File

@@ -15,7 +15,9 @@ class FlowMatchScheduler():
self.set_timesteps(num_inference_steps)
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
if shift is not None:
self.shift = shift
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
if self.extra_one_step:
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]