update sd training scripts

This commit is contained in:
Artiprocher
2026-04-24 14:30:09 +08:00
parent 5cdab9ed01
commit 3799bdc23a
23 changed files with 323 additions and 612 deletions

View File

@@ -902,7 +902,7 @@ mova_series = [
]
stable_diffusion_xl_series = [
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
"model_name": "stable_diffusion_xl_unet",
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
@@ -916,21 +916,21 @@ stable_diffusion_xl_series = [
},
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
"model_name": "stable_diffusion_xl_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
"model_name": "stable_diffusion_xl_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",

View File

@@ -1,269 +1,107 @@
import torch, math
from typing import Literal
class DDIMScheduler:
class DDIMScheduler():
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
clip_sample: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 1,
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
self.num_train_timesteps = num_train_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
self.beta_schedule = beta_schedule
self.clip_sample = clip_sample
self.set_alpha_to_one = set_alpha_to_one
self.steps_offset = steps_offset
self.prediction_type = prediction_type
self.timestep_spacing = timestep_spacing
# Compute betas
if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# SD 1.5 specific: sqrt-linear interpolation
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
self.betas = self._betas_for_alpha_bar(num_train_timesteps)
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
elif beta_schedule == "linear":
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise ValueError(f"Unsupported beta_schedule: {beta_schedule}")
# Rescale for zero SNR
if rescale_betas_zero_snr:
self.betas = self._rescale_zero_terminal_snr(self.betas)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# For the final step, there is no previous alphas_cumprod
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# Setable values (will be populated by set_timesteps)
self.num_inference_steps = None
self.timesteps = torch.from_numpy(self._default_timesteps().astype("int64"))
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
self.training = False
@staticmethod
def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
"""Create beta schedule via cosine alpha_bar function."""
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@staticmethod
def _rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""Rescale betas to have zero terminal SNR."""
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
def rescale_zero_terminal_snr(self, alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
alphas_bar = alphas_bar_sqrt ** 2
alphas = torch.cat([alphas_bar[1:], alphas_bar[:1]])
return 1 - alphas
def _default_timesteps(self):
"""Default timesteps before set_timesteps is called."""
import numpy as np
return np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
"""Compute the variance for the DDIM step."""
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
return alphas_bar
def set_timesteps(self, num_inference_steps: int = 100, denoising_strength: float = 1.0, training: bool = False, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain.
Follows FlowMatchScheduler interface: (num_inference_steps, denoising_strength, training, **kwargs)
"""
import numpy as np
if denoising_strength != 1.0:
# For img2img: adjust effective steps
num_inference_steps = int(num_inference_steps * denoising_strength)
# Compute step ratio
if self.timestep_spacing == "leading":
# leading: arange * step_ratio, reverse, then add offset
step_ratio = self.num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
timesteps = timesteps + self.steps_offset
elif self.timestep_spacing == "trailing":
# trailing: timesteps = arange(num_steps, 0, -1) * step_ratio - 1
step_ratio = self.num_train_timesteps / num_inference_steps
timesteps = (np.arange(num_inference_steps, 0, -1) * step_ratio - 1).round()[::-1]
elif self.timestep_spacing == "linspace":
# linspace: evenly spaced from num_train_timesteps - 1 to 0
timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1]
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = torch.Tensor([max_timestep])
else:
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
step_length = max_timestep / (num_inference_steps - 1)
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
self.training = training
# Clamp timesteps to valid range [0, num_train_timesteps - 1]
timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
self.num_inference_steps = num_inference_steps
if training:
self.set_training_weight()
self.training = True
else:
self.training = False
def set_training_weight(self):
"""Set timestep weights for training (similar to FlowMatchScheduler)."""
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(self.timesteps) != 1000:
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def step(self, model_output, timestep, sample, to_final: bool = False, eta: float = 0.0, **kwargs):
"""
DDIM step function.
Follows FlowMatchScheduler interface: step(model_output, timestep, sample, to_final=False)
For SD 1.5, prediction_type="epsilon" and eta=0.0 (deterministic DDIM).
"""
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
if timestep.dim() == 0:
timestep = timestep.item()
elif timestep.dim() == 1:
timestep = timestep[0].item()
# Ensure timestep is int
timestep = int(timestep)
# Find the index of the current timestep
timestep_id = torch.argmin((self.timesteps - timestep).abs()).item()
if timestep_id + 1 >= len(self.timesteps):
prev_timestep = -1
else:
prev_timestep = self.timesteps[timestep_id + 1].item()
# Get alphas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = alpha_prod_t.to(device=sample.device, dtype=sample.dtype)
alpha_prod_t_prev = alpha_prod_t_prev.to(device=sample.device, dtype=sample.dtype)
beta_prod_t = 1 - alpha_prod_t
# Compute predicted original sample (x_0)
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
elif self.prediction_type == "sample":
pred_original_sample = model_output
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
elif self.prediction_type == "v_prediction":
pred_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
prev_sample = sample * weight_x + model_output * weight_e
else:
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
# Clip sample if needed
if self.clip_sample:
pred_original_sample = pred_original_sample.clamp(-1.0, 1.0)
# Compute predicted noise (re-derived from x_0)
pred_epsilon = (sample - alpha_prod_t.sqrt() * pred_original_sample) / beta_prod_t.sqrt()
# DDIM formula: prev_sample = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * epsilon
prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + (1 - alpha_prod_t_prev).sqrt() * pred_epsilon
# Add variance noise if eta > 0 (DDIM: eta=0, DDPM: eta=1)
if eta > 0:
variance = self._get_variance(timestep, prev_timestep)
variance = variance.to(device=sample.device, dtype=sample.dtype)
std_dev_t = eta * variance.sqrt()
device = sample.device
noise = torch.randn_like(sample)
prev_sample = prev_sample + std_dev_t * noise
raise NotImplementedError(f"{self.prediction_type} is not implemented")
return prev_sample
def add_noise(self, original_samples, noise, timestep):
"""Add noise to original samples (forward diffusion).
Follows FlowMatchScheduler interface: add_noise(original_samples, noise, timestep)
"""
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
if timestep.dim() == 0:
timestep = timestep.item()
elif timestep.dim() == 1:
timestep = timestep[0].item()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = int(self.timesteps[timestep_id + 1])
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
timestep = int(timestep)
# Defensive clamp: ensure timestep is within valid range
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
sqrt_alpha_prod = sqrt_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
# Handle broadcasting for batch timesteps
while sqrt_alpha_prod.dim() < original_samples.dim():
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
sample = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return sample
def return_to_timestep(self, timestep, sample, sample_stablized):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
return noise_pred
def add_noise(self, original_samples, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
"""Return the training target for the given prediction type."""
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
if timestep.dim() == 0:
timestep = timestep.item()
elif timestep.dim() == 1:
timestep = timestep[0].item()
timestep = int(timestep)
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
if self.prediction_type == "epsilon":
return noise
elif self.prediction_type == "v_prediction":
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
elif self.prediction_type == "sample":
return sample
else:
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target
def training_weight(self, timestep):
"""Return training weight for the given timestep."""
timestep = max(0, min(int(timestep), self.num_train_timesteps - 1))
timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
return self.linear_timesteps_weights[timestep_id]
return 1.0

View File

@@ -196,19 +196,14 @@ class SDUnit_InputImageEmbedder(PipelineUnit):
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
if input_image is None:
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample()
latents = noise * pipe.scheduler.init_noise_sigma
return {"latents": latents, "input_latents": input_latents}
else:
# Inference mode: VAE encode input image, add noise for initial latent
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample()
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
return {"latents": latents}

View File

@@ -49,6 +49,7 @@ class StableDiffusionXLPipeline(BasePipeline):
SDXLUnit_PromptEmbedder(),
SDXLUnit_NoiseInitializer(),
SDXLUnit_InputImageEmbedder(),
SDXLUnit_AddTimeIdsComputer(),
]
self.model_fn = model_fn_stable_diffusion_xl
self.compilable_models = ["unet"]
@@ -94,20 +95,11 @@ class StableDiffusionXLPipeline(BasePipeline):
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
eta: float = 0.0,
guidance_rescale: float = 0.0,
original_size: tuple = None,
crops_coords_top_left: tuple = (0, 0),
target_size: tuple = None,
progress_bar_cmd=tqdm,
):
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# 1. Scheduler
self.scheduler.set_timesteps(
num_inference_steps, eta=eta,
)
self.scheduler.set_timesteps(num_inference_steps)
# 2. Three-dict input preparation
inputs_posi = {
@@ -121,9 +113,7 @@ class StableDiffusionXLPipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
"original_size": original_size,
"crops_coords_top_left": crops_coords_top_left,
"target_size": target_size,
"crops_coords_top_left": (0, 0),
}
# 3. Unit chain execution
@@ -132,18 +122,7 @@ class StableDiffusionXLPipeline(BasePipeline):
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Compute add_time_ids (micro-conditioning)
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size,
dtype=self.torch_dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
neg_add_time_ids = add_time_ids.clone()
inputs_posi["add_time_ids"] = add_time_ids
inputs_nega["add_time_ids"] = neg_add_time_ids
# 5. Denoise loop
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -183,21 +162,6 @@ class StableDiffusionXLPipeline(BasePipeline):
return image
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
# SDXL UNet doesn't have a config attribute, so we access add_embedding directly
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
# addition_time_embed_dim is the dimension of each time ID projection (256 for SDXL base)
addition_time_embed_dim = self.unet.add_time_proj.num_channels
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
f"but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=self.device)
return add_time_ids
class SDXLUnit_ShapeChecker(PipelineUnit):
def __init__(self):
@@ -294,22 +258,51 @@ class SDXLUnit_InputImageEmbedder(PipelineUnit):
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
if input_image is None:
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample()
latents = noise * pipe.scheduler.init_noise_sigma
return {"latents": latents, "input_latents": input_latents}
else:
# Inference mode: VAE encode input image, add noise for initial latent
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample()
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
return {"latents": latents}
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("add_time_ids",),
)
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
f"but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
return add_time_ids
def process(self, pipe: StableDiffusionXLPipeline, height, width):
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
pipe, original_size, crops_coords_top_left, target_size,
dtype=pipe.torch_dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
return {"add_time_ids": add_time_ids}
def model_fn_stable_diffusion_xl(
unet: SDXLUNet2DConditionModel,
latents=None,