From 675ae5e91f0dfc1901e673ea9a8c9b81f6d55c7f Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Sun, 30 Nov 2025 15:22:39 +0800 Subject: [PATCH] refactor scheduler --- diffsynth/diffusion/flow_match.py | 257 ++++++++++++++++------------- diffsynth/pipelines/flux2_image.py | 10 +- diffsynth/pipelines/flux_image.py | 2 +- diffsynth/pipelines/qwen_image.py | 2 +- diffsynth/pipelines/wan_video.py | 4 +- diffsynth/pipelines/z_image.py | 2 +- 6 files changed, 150 insertions(+), 127 deletions(-) diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 11ee0f4..0216ae6 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -1,129 +1,76 @@ import torch, math +from typing_extensions import Literal class FlowMatchScheduler(): - def __init__( - self, - num_inference_steps=100, - num_train_timesteps=1000, - shift=3.0, - sigma_max=1.0, - sigma_min=0.003/1.002, - inverse_timesteps=False, - extra_one_step=False, - reverse_sigmas=False, - exponential_shift=False, - exponential_shift_mu=None, - shift_terminal=None, - ): - self.num_train_timesteps = num_train_timesteps - self.shift = shift - self.sigma_max = sigma_max - self.sigma_min = sigma_min - self.inverse_timesteps = inverse_timesteps - self.extra_one_step = extra_one_step - self.reverse_sigmas = reverse_sigmas - self.exponential_shift = exponential_shift - self.exponential_shift_mu = exponential_shift_mu - self.shift_terminal = shift_terminal - self.set_timesteps(num_inference_steps) + def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"): + self.set_timesteps_fn = { + "FLUX.1": FlowMatchScheduler.set_timesteps_flux, + "Wan": FlowMatchScheduler.set_timesteps_wan, + "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image, + "FLUX.2": FlowMatchScheduler.set_timesteps_flux2, + "Z-Image": FlowMatchScheduler.set_timesteps_z_image, + }.get(template, FlowMatchScheduler.set_timesteps_flux) - - def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None): - if shift is not None: - self.shift = shift - sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength - if self.extra_one_step: - self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] - else: - self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) - if self.inverse_timesteps: - self.sigmas = torch.flip(self.sigmas, dims=[0]) - if self.exponential_shift: - if exponential_shift_mu is not None: - mu = exponential_shift_mu - elif dynamic_shift_len is not None: - mu = self.calculate_shift(dynamic_shift_len) - else: - mu = self.exponential_shift_mu - self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1)) - else: - self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) - if self.shift_terminal is not None: - one_minus_z = 1 - self.sigmas - scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) - self.sigmas = 1 - (one_minus_z / scale_factor) - if self.reverse_sigmas: - self.sigmas = 1 - self.sigmas - self.timesteps = self.sigmas * self.num_train_timesteps - if training: - x = self.timesteps - y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) - y_shifted = y - y.min() - bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) - self.linear_timesteps_weights = bsmntw_weighing - self.training = True - else: - self.training = False - - - def step(self, model_output, timestep, sample, to_final=False, **kwargs): - if isinstance(timestep, torch.Tensor): - timestep = timestep.cpu() - timestep_id = torch.argmin((self.timesteps - timestep).abs()) - sigma = self.sigmas[timestep_id] - if to_final or timestep_id + 1 >= len(self.timesteps): - sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 - else: - sigma_ = self.sigmas[timestep_id + 1] - prev_sample = sample + model_output * (sigma_ - sigma) - return prev_sample + @staticmethod + def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.003/1.002 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps - - def return_to_timestep(self, timestep, sample, sample_stablized): - if isinstance(timestep, torch.Tensor): - timestep = timestep.cpu() - timestep_id = torch.argmin((self.timesteps - timestep).abs()) - sigma = self.sigmas[timestep_id] - model_output = (sample - sample_stablized) / sigma - return model_output + @staticmethod + def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 5 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps - - def add_noise(self, original_samples, noise, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.cpu() - timestep_id = torch.argmin((self.timesteps - timestep).abs()) - sigma = self.sigmas[timestep_id] - sample = (1 - sigma) * original_samples + sigma * noise - return sample - - - def training_target(self, sample, noise, timestep): - target = noise - sample - return target - - - def training_weight(self, timestep): - timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) - weights = self.linear_timesteps_weights[timestep_id] - return weights - - - def calculate_shift( - self, - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 8192, - base_shift: float = 0.5, - max_shift: float = 0.9, - ): + @staticmethod + def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu - - def compute_empirical_mu(self, image_seq_len: int, num_steps: int) -> float: + + @staticmethod + def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None): + sigma_min = 0.0 + sigma_max = 1.0 + num_train_timesteps = 1000 + shift_terminal = 0.02 + # Sigmas + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + # Mu + if exponential_shift_mu is not None: + mu = exponential_shift_mu + elif dynamic_shift_len is not None: + mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len) + else: + mu = 0.8 + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1 - sigmas + scale_factor = one_minus_z[-1] / (1 - shift_terminal) + sigmas = 1 - (one_minus_z / scale_factor) + # Timesteps + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def compute_empirical_mu(image_seq_len, num_steps): a1, b1 = 8.73809524e-05, 1.89833333 a2, b2 = 0.00016927, 0.45666666 @@ -138,4 +85,84 @@ class FlowMatchScheduler(): b = m_200 - 200.0 * a mu = a * num_steps + b - return float(mu) \ No newline at end of file + return float(mu) + + @staticmethod + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None): + sigma_min = 1 / num_inference_steps + sigma_max = 1.0 + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) + mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) + sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + @staticmethod + def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None): + sigma_min = 0.0 + sigma_max = 1.0 + shift = 3 if shift is None else shift + num_train_timesteps = 1000 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * num_train_timesteps + return sigmas, timesteps + + def set_training_weight(self, num_inference_steps): + x = self.timesteps + y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + + def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs): + self.sigmas, self.timesteps = self.set_timesteps_fn( + num_inference_steps=num_inference_steps, + denoising_strength=denoising_strength, + **kwargs, + ) + if training: + self.set_training_weight(num_inference_steps) + self.training = True + else: + self.training = False + + def step(self, model_output, timestep, sample, to_final=False, **kwargs): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + if to_final or timestep_id + 1 >= len(self.timesteps): + sigma_ = 0 + else: + sigma_ = self.sigmas[timestep_id + 1] + prev_sample = sample + model_output * (sigma_ - sigma) + return prev_sample + + def return_to_timestep(self, timestep, sample, sample_stablized): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output + + def add_noise(self, original_samples, noise, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + sample = (1 - sigma) * original_samples + sigma * noise + return sample + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + weights = self.linear_timesteps_weights[timestep_id] + return weights diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 630667b..8b00469 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -23,7 +23,7 @@ class Flux2ImagePipeline(BasePipeline): device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, ) - self.scheduler = FlowMatchScheduler() + self.scheduler = FlowMatchScheduler("FLUX.2") self.text_encoder: Flux2TextEncoder = None self.dit: Flux2DiT = None self.vae: Flux2VAE = None @@ -86,6 +86,8 @@ class Flux2ImagePipeline(BasePipeline): # Progress bar progress_bar_cmd = tqdm, ): + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16) + # Parameters inputs_posi = { "prompt": prompt, @@ -103,12 +105,6 @@ class Flux2ImagePipeline(BasePipeline): for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - # using dynamic shift Scheduler - self.scheduler.exponential_shift = True - self.scheduler.sigma_min = 1 / num_inference_steps - mu = self.scheduler.compute_empirical_mu(inputs_shared["latents"].shape[1], num_inference_steps) - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, exponential_shift_mu=mu) - # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 2ef2617..1ee5635 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -60,7 +60,7 @@ class FluxImagePipeline(BasePipeline): device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, ) - self.scheduler = FlowMatchScheduler() + self.scheduler = FlowMatchScheduler("FLUX.1") self.tokenizer_1: CLIPTokenizer = None self.tokenizer_2: T5TokenizerFast = None self.text_encoder_1: FluxTextEncoderClip = None diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 03d58cf..fc6581e 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -24,7 +24,7 @@ class QwenImagePipeline(BasePipeline): ) from transformers import Qwen2Tokenizer, Qwen2VLProcessor - self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02) + self.scheduler = FlowMatchScheduler("Qwen-Image") self.text_encoder: QwenImageTextEncoder = None self.dit: QwenImageDiT = None self.vae: QwenImageVAE = None diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 0c9ce6f..fa43db1 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -35,7 +35,7 @@ class WanVideoPipeline(BasePipeline): device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1 ) - self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True) + self.scheduler = FlowMatchScheduler("Wan") self.tokenizer: HuggingfaceTokenizer = None self.audio_processor: Wav2Vec2Processor = None self.text_encoder: WanTextEncoder = None @@ -283,7 +283,7 @@ class WanVideoPipeline(BasePipeline): models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): # Switch DiT if necessary - if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2: self.load_models_to_device(self.in_iteration_models_2) models["dit"] = self.dit2 models["vace"] = self.vace2 diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index b1ee420..b6ad72c 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -23,7 +23,7 @@ class ZImagePipeline(BasePipeline): device=device, torch_dtype=torch_dtype, height_division_factor=16, width_division_factor=16, ) - self.scheduler = FlowMatchScheduler() + self.scheduler = FlowMatchScheduler("Z-Image") self.text_encoder: ZImageTextEncoder = None self.dit: ZImageDiT = None self.vae_encoder: FluxVAEEncoder = None