From fb892bd860d268e1bdad0fe7434a8e8c6ae2aa59 Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Thu, 27 Nov 2025 19:50:15 +0800 Subject: [PATCH] using dynamic shift Scheduler in flux2 --- diffsynth/diffusion/flow_match.py | 17 +++++++++++++++++ diffsynth/pipelines/flux2_image.py | 9 ++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 83c6a82..11ee0f4 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -122,3 +122,20 @@ class FlowMatchScheduler(): b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu + + def compute_empirical_mu(self, image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) \ No newline at end of file diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 3757a3b..630667b 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -86,9 +86,6 @@ class Flux2ImagePipeline(BasePipeline): # Progress bar progress_bar_cmd = tqdm, ): - # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) - # Parameters inputs_posi = { "prompt": prompt, @@ -106,6 +103,12 @@ class Flux2ImagePipeline(BasePipeline): for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + # using dynamic shift Scheduler + self.scheduler.exponential_shift = True + self.scheduler.sigma_min = 1 / num_inference_steps + mu = self.scheduler.compute_empirical_mu(inputs_shared["latents"].shape[1], num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, exponential_shift_mu=mu) + # Denoise self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models}