mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
refactor scheduler
This commit is contained in:
@@ -1,129 +1,76 @@
|
|||||||
import torch, math
|
import torch, math
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
class FlowMatchScheduler():
|
class FlowMatchScheduler():
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
||||||
self,
|
self.set_timesteps_fn = {
|
||||||
num_inference_steps=100,
|
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||||
num_train_timesteps=1000,
|
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||||
shift=3.0,
|
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||||
sigma_max=1.0,
|
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||||
sigma_min=0.003/1.002,
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||||
inverse_timesteps=False,
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||||
extra_one_step=False,
|
|
||||||
reverse_sigmas=False,
|
|
||||||
exponential_shift=False,
|
|
||||||
exponential_shift_mu=None,
|
|
||||||
shift_terminal=None,
|
|
||||||
):
|
|
||||||
self.num_train_timesteps = num_train_timesteps
|
|
||||||
self.shift = shift
|
|
||||||
self.sigma_max = sigma_max
|
|
||||||
self.sigma_min = sigma_min
|
|
||||||
self.inverse_timesteps = inverse_timesteps
|
|
||||||
self.extra_one_step = extra_one_step
|
|
||||||
self.reverse_sigmas = reverse_sigmas
|
|
||||||
self.exponential_shift = exponential_shift
|
|
||||||
self.exponential_shift_mu = exponential_shift_mu
|
|
||||||
self.shift_terminal = shift_terminal
|
|
||||||
self.set_timesteps(num_inference_steps)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None):
|
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
if shift is not None:
|
sigma_min = 0.003/1.002
|
||||||
self.shift = shift
|
sigma_max = 1.0
|
||||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
shift = 3 if shift is None else shift
|
||||||
if self.extra_one_step:
|
num_train_timesteps = 1000
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
else:
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
if self.inverse_timesteps:
|
timesteps = sigmas * num_train_timesteps
|
||||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
return sigmas, timesteps
|
||||||
if self.exponential_shift:
|
|
||||||
if exponential_shift_mu is not None:
|
|
||||||
mu = exponential_shift_mu
|
|
||||||
elif dynamic_shift_len is not None:
|
|
||||||
mu = self.calculate_shift(dynamic_shift_len)
|
|
||||||
else:
|
|
||||||
mu = self.exponential_shift_mu
|
|
||||||
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
|
||||||
else:
|
|
||||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
|
||||||
if self.shift_terminal is not None:
|
|
||||||
one_minus_z = 1 - self.sigmas
|
|
||||||
scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
|
|
||||||
self.sigmas = 1 - (one_minus_z / scale_factor)
|
|
||||||
if self.reverse_sigmas:
|
|
||||||
self.sigmas = 1 - self.sigmas
|
|
||||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
|
||||||
if training:
|
|
||||||
x = self.timesteps
|
|
||||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
|
||||||
y_shifted = y - y.min()
|
|
||||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
|
||||||
self.linear_timesteps_weights = bsmntw_weighing
|
|
||||||
self.training = True
|
|
||||||
else:
|
|
||||||
self.training = False
|
|
||||||
|
|
||||||
|
|
||||||
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
|
||||||
sigma = self.sigmas[timestep_id]
|
|
||||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
|
||||||
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
|
||||||
else:
|
|
||||||
sigma_ = self.sigmas[timestep_id + 1]
|
|
||||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
|
||||||
return prev_sample
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def return_to_timestep(self, timestep, sample, sample_stablized):
|
def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
if isinstance(timestep, torch.Tensor):
|
sigma_min = 0.0
|
||||||
timestep = timestep.cpu()
|
sigma_max = 1.0
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
shift = 5 if shift is None else shift
|
||||||
sigma = self.sigmas[timestep_id]
|
num_train_timesteps = 1000
|
||||||
model_output = (sample - sample_stablized) / sigma
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
return model_output
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def add_noise(self, original_samples, noise, timestep):
|
def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
|
||||||
if isinstance(timestep, torch.Tensor):
|
|
||||||
timestep = timestep.cpu()
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
|
||||||
sigma = self.sigmas[timestep_id]
|
|
||||||
sample = (1 - sigma) * original_samples + sigma * noise
|
|
||||||
return sample
|
|
||||||
|
|
||||||
|
|
||||||
def training_target(self, sample, noise, timestep):
|
|
||||||
target = noise - sample
|
|
||||||
return target
|
|
||||||
|
|
||||||
|
|
||||||
def training_weight(self, timestep):
|
|
||||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
|
||||||
weights = self.linear_timesteps_weights[timestep_id]
|
|
||||||
return weights
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_shift(
|
|
||||||
self,
|
|
||||||
image_seq_len,
|
|
||||||
base_seq_len: int = 256,
|
|
||||||
max_seq_len: int = 8192,
|
|
||||||
base_shift: float = 0.5,
|
|
||||||
max_shift: float = 0.9,
|
|
||||||
):
|
|
||||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
b = base_shift - m * base_seq_len
|
b = base_shift - m * base_seq_len
|
||||||
mu = image_seq_len * m + b
|
mu = image_seq_len * m + b
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
def compute_empirical_mu(self, image_seq_len: int, num_steps: int) -> float:
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
shift_terminal = 0.02
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
|
||||||
|
sigmas = 1 - (one_minus_z / scale_factor)
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_empirical_mu(image_seq_len, num_steps):
|
||||||
a1, b1 = 8.73809524e-05, 1.89833333
|
a1, b1 = 8.73809524e-05, 1.89833333
|
||||||
a2, b2 = 0.00016927, 0.45666666
|
a2, b2 = 0.00016927, 0.45666666
|
||||||
|
|
||||||
@@ -138,4 +85,84 @@ class FlowMatchScheduler():
|
|||||||
b = m_200 - 200.0 * a
|
b = m_200 - 200.0 * a
|
||||||
mu = a * num_steps + b
|
mu = a * num_steps + b
|
||||||
|
|
||||||
return float(mu)
|
return float(mu)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||||
|
sigma_min = 1 / num_inference_steps
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||||
|
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
shift = 3 if shift is None else shift
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
def set_training_weight(self, num_inference_steps):
|
||||||
|
x = self.timesteps
|
||||||
|
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
||||||
|
y_shifted = y - y.min()
|
||||||
|
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||||
|
self.linear_timesteps_weights = bsmntw_weighing
|
||||||
|
|
||||||
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||||
|
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||||
|
num_inference_steps=num_inference_steps,
|
||||||
|
denoising_strength=denoising_strength,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if training:
|
||||||
|
self.set_training_weight(num_inference_steps)
|
||||||
|
self.training = True
|
||||||
|
else:
|
||||||
|
self.training = False
|
||||||
|
|
||||||
|
def step(self, model_output, timestep, sample, to_final=False, **kwargs):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||||
|
sigma_ = 0
|
||||||
|
else:
|
||||||
|
sigma_ = self.sigmas[timestep_id + 1]
|
||||||
|
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||||
|
return prev_sample
|
||||||
|
|
||||||
|
def return_to_timestep(self, timestep, sample, sample_stablized):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
model_output = (sample - sample_stablized) / sigma
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.cpu()
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||||
|
sigma = self.sigmas[timestep_id]
|
||||||
|
sample = (1 - sigma) * original_samples + sigma * noise
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def training_target(self, sample, noise, timestep):
|
||||||
|
target = noise - sample
|
||||||
|
return target
|
||||||
|
|
||||||
|
def training_weight(self, timestep):
|
||||||
|
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||||
|
weights = self.linear_timesteps_weights[timestep_id]
|
||||||
|
return weights
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
)
|
)
|
||||||
self.scheduler = FlowMatchScheduler()
|
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||||
self.text_encoder: Flux2TextEncoder = None
|
self.text_encoder: Flux2TextEncoder = None
|
||||||
self.dit: Flux2DiT = None
|
self.dit: Flux2DiT = None
|
||||||
self.vae: Flux2VAE = None
|
self.vae: Flux2VAE = None
|
||||||
@@ -86,6 +86,8 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
# Progress bar
|
# Progress bar
|
||||||
progress_bar_cmd = tqdm,
|
progress_bar_cmd = tqdm,
|
||||||
):
|
):
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@@ -103,12 +105,6 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
for unit in self.units:
|
for unit in self.units:
|
||||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
# using dynamic shift Scheduler
|
|
||||||
self.scheduler.exponential_shift = True
|
|
||||||
self.scheduler.sigma_min = 1 / num_inference_steps
|
|
||||||
mu = self.scheduler.compute_empirical_mu(inputs_shared["latents"].shape[1], num_inference_steps)
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, exponential_shift_mu=mu)
|
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(self.in_iteration_models)
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
)
|
)
|
||||||
self.scheduler = FlowMatchScheduler()
|
self.scheduler = FlowMatchScheduler("FLUX.1")
|
||||||
self.tokenizer_1: CLIPTokenizer = None
|
self.tokenizer_1: CLIPTokenizer = None
|
||||||
self.tokenizer_2: T5TokenizerFast = None
|
self.tokenizer_2: T5TokenizerFast = None
|
||||||
self.text_encoder_1: FluxTextEncoderClip = None
|
self.text_encoder_1: FluxTextEncoderClip = None
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
)
|
)
|
||||||
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
||||||
|
|
||||||
self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
|
self.scheduler = FlowMatchScheduler("Qwen-Image")
|
||||||
self.text_encoder: QwenImageTextEncoder = None
|
self.text_encoder: QwenImageTextEncoder = None
|
||||||
self.dit: QwenImageDiT = None
|
self.dit: QwenImageDiT = None
|
||||||
self.vae: QwenImageVAE = None
|
self.vae: QwenImageVAE = None
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
||||||
)
|
)
|
||||||
self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
|
self.scheduler = FlowMatchScheduler("Wan")
|
||||||
self.tokenizer: HuggingfaceTokenizer = None
|
self.tokenizer: HuggingfaceTokenizer = None
|
||||||
self.audio_processor: Wav2Vec2Processor = None
|
self.audio_processor: Wav2Vec2Processor = None
|
||||||
self.text_encoder: WanTextEncoder = None
|
self.text_encoder: WanTextEncoder = None
|
||||||
@@ -283,7 +283,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
# Switch DiT if necessary
|
# Switch DiT if necessary
|
||||||
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
|
if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||||
self.load_models_to_device(self.in_iteration_models_2)
|
self.load_models_to_device(self.in_iteration_models_2)
|
||||||
models["dit"] = self.dit2
|
models["dit"] = self.dit2
|
||||||
models["vace"] = self.vace2
|
models["vace"] = self.vace2
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class ZImagePipeline(BasePipeline):
|
|||||||
device=device, torch_dtype=torch_dtype,
|
device=device, torch_dtype=torch_dtype,
|
||||||
height_division_factor=16, width_division_factor=16,
|
height_division_factor=16, width_division_factor=16,
|
||||||
)
|
)
|
||||||
self.scheduler = FlowMatchScheduler()
|
self.scheduler = FlowMatchScheduler("Z-Image")
|
||||||
self.text_encoder: ZImageTextEncoder = None
|
self.text_encoder: ZImageTextEncoder = None
|
||||||
self.dit: ZImageDiT = None
|
self.dit: ZImageDiT = None
|
||||||
self.vae_encoder: FluxVAEEncoder = None
|
self.vae_encoder: FluxVAEEncoder = None
|
||||||
|
|||||||
Reference in New Issue
Block a user