mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
270 lines
12 KiB
Python
270 lines
12 KiB
Python
import torch, math
|
|
from typing import Literal
|
|
|
|
|
|
class DDIMScheduler:
    """DDIM noise scheduler with a FlowMatchScheduler-style interface.

    Exposes ``set_timesteps``, ``step``, ``add_noise``, ``training_target`` and
    ``training_weight`` with the same call shapes as FlowMatchScheduler so that
    pipelines can swap schedulers. Timesteps are integer indices into the
    ``num_train_timesteps``-long ``alphas_cumprod`` table.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
        clip_sample: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 1,
        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
        timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        """Build the beta / alpha tables.

        Args:
            num_train_timesteps: Length of the training noise schedule.
            beta_start: First beta of the "linear" / "scaled_linear" schedules.
            beta_end: Last beta of the "linear" / "scaled_linear" schedules.
            beta_schedule: "linear", "scaled_linear" (linear in sqrt(beta)), or
                "squaredcos_cap_v2" (cosine alpha-bar schedule).
            clip_sample: Clamp the predicted x_0 to [-1, 1] inside ``step``.
            set_alpha_to_one: If True the alpha_cumprod used after the last step
                is 1.0; otherwise ``alphas_cumprod[0]`` is reused.
            steps_offset: Added to every timestep under "leading" spacing.
            prediction_type: What the model outputs: noise ("epsilon"), x_0
                ("sample") or velocity ("v_prediction").
            timestep_spacing: Placement of inference timesteps (see ``set_timesteps``).
            rescale_betas_zero_snr: Rescale betas so the last timestep has zero SNR.

        Raises:
            ValueError: If ``beta_schedule`` is not one of the supported names.
        """
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = beta_schedule
        self.clip_sample = clip_sample
        self.set_alpha_to_one = set_alpha_to_one
        self.steps_offset = steps_offset
        self.prediction_type = prediction_type
        self.timestep_spacing = timestep_spacing
        self.rescale_betas_zero_snr = rescale_betas_zero_snr

        # Compute betas
        if beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # SD 1.5 specific: sqrt-linear interpolation
            self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            self.betas = self._betas_for_alpha_bar(num_train_timesteps)
        else:
            raise ValueError(f"Unsupported beta_schedule: {beta_schedule}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = self._rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # alpha_cumprod used as the "previous" value after the final step
        # (there is no real previous entry in the table).
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # Populated by set_timesteps; until then every training timestep, descending.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(self._default_timesteps())
        self.training = False

    @staticmethod
    def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
        """Create a beta schedule from the cosine alpha_bar function, capping each beta at ``max_beta``."""
        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
        return torch.tensor(betas, dtype=torch.float32)

    @staticmethod
    def _rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
        """Rescale betas so sqrt(alpha_bar) keeps its first value and reaches 0 at the last timestep."""
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_bar_sqrt = alphas_cumprod.sqrt()

        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

        # Shift so the last value is 0, then scale so the first value is unchanged.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert the cumulative products back to per-step alphas.
        alphas_bar = alphas_bar_sqrt ** 2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        return 1 - alphas

    def _default_timesteps(self):
        """Default timesteps before set_timesteps is called: num_train_timesteps-1 down to 0 (int64)."""
        import numpy as np
        return np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)

    @staticmethod
    def _timestep_to_int(timestep) -> int:
        """Convert a timestep argument to a Python int.

        Accepts a Python number, a 0-d tensor, or a 1-d tensor. For a 1-d tensor
        only the first entry is used, so a batch is treated as sharing one timestep.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
            if timestep.dim() == 0:
                timestep = timestep.item()
            elif timestep.dim() == 1:
                timestep = timestep[0].item()
        return int(timestep)

    def _clamp_timestep(self, timestep: int) -> int:
        """Clamp an integer timestep into ``[0, num_train_timesteps - 1]``."""
        return max(0, min(timestep, self.num_train_timesteps - 1))

    def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
        """Return the DDIM variance for the transition ``timestep -> prev_timestep``.

        The value is (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev). A negative
        ``prev_timestep`` means the final step and uses ``final_alpha_cumprod``.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        return variance

    def set_timesteps(self, num_inference_steps: int = 100, denoising_strength: float = 1.0, training: bool = False, **kwargs):
        """
        Set the discrete timesteps used for the diffusion chain.

        Follows FlowMatchScheduler interface: (num_inference_steps, denoising_strength, training, **kwargs)

        The full ``num_inference_steps`` schedule is built in descending order
        according to ``timestep_spacing``. For example, with num_train_timesteps=1000
        and 10 steps:
          - "leading":  arange(N) * (T // N), reversed, plus ``steps_offset`` -> 901, 801, ..., 1
          - "trailing": arange(N, 0, -1) * (T / N) - 1                          -> 999, 899, ..., 99
          - "linspace": linspace(0, T - 1, N), reversed                          -> 999, ..., 0
        Values are clamped to [0, T - 1].

        With ``denoising_strength < 1`` (img2img) only the last
        ``int(num_inference_steps * denoising_strength)`` entries are kept, with at
        least one. The chain therefore starts at a lower noise level instead of
        merely taking fewer steps from pure noise.

        Args:
            num_inference_steps: Number of steps in the full schedule (>= 1).
            denoising_strength: Fraction of the schedule (from its low-noise end) to keep.
            training: If True, also compute ``linear_timesteps_weights``.
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            ValueError: If ``num_inference_steps`` < 1 or ``timestep_spacing`` is unsupported.
        """
        import numpy as np

        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")

        if self.timestep_spacing == "leading":
            step_ratio = self.num_train_timesteps // num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
            timesteps = timesteps + self.steps_offset
        elif self.timestep_spacing == "trailing":
            step_ratio = self.num_train_timesteps / num_inference_steps
            # Already descending; it must not be reversed again.
            timesteps = (np.arange(num_inference_steps, 0, -1) * step_ratio - 1).round()
        elif self.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1]
        else:
            raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")

        # Clamp timesteps to valid range [0, num_train_timesteps - 1]; astype also
        # yields a fresh contiguous array (torch.from_numpy rejects negative strides).
        timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1).astype(np.int64)

        # img2img: keep only the low-noise tail of the full schedule.
        kept = min(num_inference_steps, max(1, int(num_inference_steps * denoising_strength)))
        timesteps = np.ascontiguousarray(timesteps[num_inference_steps - kept:])

        self.timesteps = torch.from_numpy(timesteps)
        self.num_inference_steps = len(self.timesteps)

        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def set_training_weight(self):
        """Compute per-timestep loss weights for training (same empirical formula as FlowMatchScheduler).

        A Gaussian bump centred on the middle of the training range is shifted to
        a zero minimum and normalised. For schedules whose length differs from
        ``num_train_timesteps`` it is then rescaled by ``len / num_train_timesteps``
        and offset by its second entry. The result is stored in
        ``self.linear_timesteps_weights``, aligned with ``self.timesteps``.
        """
        steps = self.num_train_timesteps
        x = self.timesteps.to(torch.float32)
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        total = y_shifted.sum()
        if total <= 0:
            # A single (or all-identical) timestep makes the normalisation 0/0 and
            # leaves no second entry, so fall back to uniform weights.
            self.linear_timesteps_weights = torch.ones_like(y)
            return
        weights = y_shifted * (steps / total)
        if len(self.timesteps) != steps:
            # Empirical correction for schedules that are not the full training length.
            weights = weights * (len(self.timesteps) / steps)
            weights = weights + weights[1]
        self.linear_timesteps_weights = weights

    def step(self, model_output, timestep, sample, to_final: bool = False, eta: float = 0.0, **kwargs):
        """
        DDIM step function.

        Follows FlowMatchScheduler interface: step(model_output, timestep, sample, to_final=False)

        For SD 1.5, prediction_type="epsilon" and eta=0.0 (deterministic DDIM).

        Args:
            model_output: Network prediction, interpreted according to ``prediction_type``.
            timestep: Current timestep (int, 0-d tensor, or 1-d tensor whose first entry is used).
            sample: Current noisy sample x_t.
            to_final: Step straight to the end of the chain (uses ``final_alpha_cumprod``)
                instead of to the next scheduled timestep.
            eta: Stochasticity. 0 gives deterministic DDIM; larger values inject
                ``eta * sqrt(variance)`` Gaussian noise (eta=1 corresponds to DDPM).
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            The sample at the previous (less noisy) timestep.

        Raises:
            ValueError: If ``prediction_type`` is unsupported.
        """
        timestep = self._clamp_timestep(self._timestep_to_int(timestep))

        # Locate the current timestep in the schedule; the next entry is the target.
        timestep_id = torch.argmin((self.timesteps - timestep).abs()).item()
        if to_final or timestep_id + 1 >= len(self.timesteps):
            prev_timestep = -1
        else:
            prev_timestep = int(self.timesteps[timestep_id + 1].item())

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        alpha_prod_t = alpha_prod_t.to(device=sample.device, dtype=sample.dtype)
        alpha_prod_t_prev = alpha_prod_t_prev.to(device=sample.device, dtype=sample.dtype)
        beta_prod_t = 1 - alpha_prod_t

        # Compute predicted original sample (x_0)
        if self.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        elif self.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.prediction_type == "v_prediction":
            pred_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
        else:
            raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")

        if self.clip_sample:
            pred_original_sample = pred_original_sample.clamp(-1.0, 1.0)

        # Re-derive the noise from (possibly clipped) x_0 so both stay consistent.
        pred_epsilon = (sample - alpha_prod_t.sqrt() * pred_original_sample) / beta_prod_t.sqrt()

        # DDIM (Song et al., eq. 12):
        #   x_prev = sqrt(a_prev) * x0 + sqrt(1 - a_prev - sigma^2) * eps + sigma * z
        # The direction term must shrink by sigma^2 when noise is injected.
        std_dev_t = None
        if eta > 0:
            variance = self._get_variance(timestep, prev_timestep).to(device=sample.device, dtype=sample.dtype)
            std_dev_t = eta * variance.sqrt()
            # clamp guards eta > 1, where the coefficient would go negative.
            direction_coeff = (1 - alpha_prod_t_prev - std_dev_t ** 2).clamp(min=0).sqrt()
        else:
            direction_coeff = (1 - alpha_prod_t_prev).sqrt()

        prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + direction_coeff * pred_epsilon

        if std_dev_t is not None:
            prev_sample = prev_sample + std_dev_t * torch.randn_like(sample)

        return prev_sample

    def add_noise(self, original_samples, noise, timestep):
        """Add noise to original samples (forward diffusion).

        Follows FlowMatchScheduler interface: add_noise(original_samples, noise, timestep)

        Returns ``sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise``. The
        timestep is collapsed to one int (first entry of a 1-d tensor) and clamped
        to the valid range.
        """
        timestep = self._clamp_timestep(self._timestep_to_int(timestep))

        sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
        sqrt_alpha_prod = sqrt_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)

        # Append trailing singleton dims so the coefficients broadcast over the sample.
        while sqrt_alpha_prod.dim() < original_samples.dim():
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise

    def training_target(self, sample, noise, timestep):
        """Return the regression target for ``prediction_type``.

        ``sample`` is the clean x_0 and ``noise`` the noise passed to ``add_noise``:
          - "epsilon":      noise
          - "v_prediction": sqrt(a_t) * noise - sqrt(1 - a_t) * sample
          - "sample":       sample

        Raises:
            ValueError: If ``prediction_type`` is unsupported.
        """
        timestep = self._clamp_timestep(self._timestep_to_int(timestep))
        if self.prediction_type == "epsilon":
            return noise
        elif self.prediction_type == "v_prediction":
            sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt().to(device=sample.device, dtype=sample.dtype)
            sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt().to(device=sample.device, dtype=sample.dtype)
            return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        elif self.prediction_type == "sample":
            return sample
        else:
            raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")

    def training_weight(self, timestep):
        """Return the training weight of the scheduled timestep closest to ``timestep``.

        Requires ``set_timesteps(..., training=True)`` to have been called, since
        that is what creates ``linear_timesteps_weights``.
        """
        timestep = self._clamp_timestep(self._timestep_to_int(timestep))
        timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
        timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
        return self.linear_timesteps_weights[timestep_id]
|