From b61131c69318b47fe96d796439b91f91074497d3 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 21 Jan 2026 15:44:15 +0800 Subject: [PATCH] improve flux2 training performance --- diffsynth/diffusion/flow_match.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index bb5fbc5..2d6b367 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -89,13 +89,18 @@ class FlowMatchScheduler(): return float(mu) @staticmethod - def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16): + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None): sigma_min = 1 / num_inference_steps sigma_max = 1.0 num_train_timesteps = 1000 sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps) - mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) + if dynamic_shift_len is None: + # If you ask me why I set mu=0.8, + # I can only say that it yields better training results. + mu = 0.8 + else: + mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps) sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1)) timesteps = sigmas * num_train_timesteps return sigmas, timesteps