import math

import torch
from typing_extensions import Literal


class FlowMatchScheduler:
    """Flow-matching scheduler whose sigma/timestep schedule depends on a model template.

    Call ``set_timesteps`` first: it fills ``self.sigmas`` and ``self.timesteps``
    (timesteps are ``sigmas * 1000``). ``step``, ``return_to_timestep``,
    ``add_noise`` and ``training_weight`` then operate on the schedule entry whose
    timestep is nearest to the one they are given.
    """

    def __init__(
        self,
        template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1",
    ):
        """Select the schedule function for ``template``.

        An unrecognised template name silently falls back to the FLUX.1 schedule.
        """
        self.set_timesteps_fn = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
        }.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

    @staticmethod
    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
        """FLUX.1 schedule.

        ``num_inference_steps`` sigmas are spaced linearly from the start sigma
        down to ``sigma_min`` (inclusive). They are then warped by
        ``shift * s / (1 + (shift - 1) * s)`` (default shift 3).

        Returns:
            ``(sigmas, timesteps)`` with ``timesteps = sigmas * 1000``.
        """
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
        """Wan schedule.

        ``num_inference_steps`` sigmas are spaced linearly from the start sigma
        towards 0, with the final 0 excluded. They are shift-warped with a
        default shift of 5.

        Returns:
            ``(sigmas, timesteps)`` with ``timesteps = sigmas * 1000``.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 5 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
        """Return ``mu``, linearly interpolated in ``image_seq_len``.

        The line passes through (``base_seq_len``, ``base_shift``) and
        (``max_seq_len``, ``max_shift``). The result is not clamped.
        """
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu

    @staticmethod
    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        """Qwen-Image schedule: exponential shift, then stretch so the last sigma is 0.02.

        ``mu`` is taken from ``exponential_shift_mu`` if given. Otherwise it is
        derived from ``dynamic_shift_len``; failing both, it defaults to 0.8.

        Returns:
            ``(sigmas, timesteps)`` with ``timesteps = sigmas * 1000``.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        shift_terminal = 0.02
        # Sigmas: linear from the start sigma towards 0, excluding the final 0.
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Shift terminal: rescale (1 - sigma) so the last sigma lands on shift_terminal.
        one_minus_z = 1 - sigmas
        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
        # A single full-strength step has last sigma == 1, so scale_factor == 0.
        # Stretching would compute 0/0 (NaN), so leave the schedule unchanged there.
        if scale_factor != 0:
            sigmas = 1 - (one_minus_z / scale_factor)
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def compute_empirical_mu(image_seq_len, num_steps):
        """Return the FLUX.2 shift ``mu`` from two empirical linear fits.

        For ``image_seq_len > 4300`` only the 200-step fit is used. Otherwise
        ``mu`` is linearly interpolated in ``num_steps`` between the 10-step and
        200-step fits.
        """
        a1, b1 = 8.73809524e-05, 1.89833333
        a2, b2 = 0.00016927, 0.45666666

        if image_seq_len > 4300:
            mu = a2 * image_seq_len + b2
            return float(mu)

        m_200 = a2 * image_seq_len + b2
        m_10 = a1 * image_seq_len + b1

        # Line through (10, m_10) and (200, m_200), evaluated at num_steps.
        a = (m_200 - m_10) / 190.0
        b = m_200 - 200.0 * a
        mu = a * num_steps + b

        return float(mu)

    @staticmethod
    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
        """FLUX.2 schedule.

        Sigmas are spaced linearly from the start sigma down to
        ``1 / num_inference_steps``. They are then exponentially shifted with
        ``mu = compute_empirical_mu(dynamic_shift_len, num_inference_steps)``.

        Raises:
            TypeError: if ``dynamic_shift_len`` is not provided (it has no usable default).

        Returns:
            ``(sigmas, timesteps)`` with ``timesteps = sigmas * 1000``.
        """
        if dynamic_shift_len is None:
            # compute_empirical_mu compares its input to an int; fail with a clear
            # message instead of "'>' not supported between 'NoneType' and 'int'".
            raise TypeError(
                "set_timesteps_flux2() requires `dynamic_shift_len` "
                "(the image sequence length used to compute the timestep shift)"
            )
        sigma_min = 1 / num_inference_steps
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
        """Z-Image schedule.

        This is the same construction as the Wan schedule (linear towards 0,
        final 0 excluded, then shift-warped), but with a default shift of 3.

        Returns:
            ``(sigmas, timesteps)`` with ``timesteps = sigmas * 1000``.
        """
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    def set_training_weight(self):
        """Compute per-timestep loss weights into ``self.linear_timesteps_weights``.

        The weights are ``exp(-2 * ((t - 500) / 1000) ** 2)``, shifted so the
        minimum is 0, then scaled so they sum to 1000.
        """
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        # NOTE(review): if every timestep is equal (e.g. a 1-step schedule),
        # y_shifted.sum() is 0 and the weights become NaN; confirm intended handling.
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        self.linear_timesteps_weights = bsmntw_weighing

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
        """Build ``self.sigmas`` / ``self.timesteps`` with the template's schedule.

        Extra ``kwargs`` (e.g. ``shift``, ``dynamic_shift_len``) are forwarded to
        the schedule function. With ``training=True`` the training weights are
        also computed.
        """
        self.sigmas, self.timesteps = self.set_timesteps_fn(
            num_inference_steps=num_inference_steps,
            denoising_strength=denoising_strength,
            **kwargs,
        )
        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def _nearest_timestep_id(self, timestep):
        """Return the index of the ``self.timesteps`` entry closest to ``timestep``."""
        if isinstance(timestep, torch.Tensor):
            # Bring the query onto the schedule's device so the subtraction never mixes devices.
            timestep = timestep.to(self.timesteps.device)
        return torch.argmin((self.timesteps - timestep).abs())

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        """Take one Euler step from ``timestep`` to the next sigma.

        The step goes to sigma 0 when ``to_final`` is set or ``timestep`` is the
        last entry. ``kwargs`` are accepted and ignored.
        """
        timestep_id = self._nearest_timestep_id(timestep)
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        """Return the model output that maps ``sample_stablized`` back to ``sample``.

        Computed as ``(sample - sample_stablized) / sigma`` at ``timestep``.
        """
        timestep_id = self._nearest_timestep_id(timestep)
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        """Interpolate linearly: ``(1 - sigma) * original_samples + sigma * noise``."""
        timestep_id = self._nearest_timestep_id(timestep)
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        """Return the flow-matching velocity target ``noise - sample`` (``timestep`` unused)."""
        target = noise - sample
        return target

    def training_weight(self, timestep):
        """Return the loss weight at the nearest timestep (requires ``set_timesteps(training=True)``)."""
        timestep_id = self._nearest_timestep_id(timestep)
        weights = self.linear_timesteps_weights[timestep_id]
        return weights