mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
sd
This commit is contained in:
@@ -900,6 +900,53 @@ mova_series = [
|
||||
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
||||
},
|
||||
]
|
||||
stable_diffusion_xl_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
|
||||
"model_name": "stable_diffusion_xl_unet",
|
||||
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
|
||||
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
|
||||
"model_name": "stable_diffusion_xl_text_encoder",
|
||||
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
|
||||
"model_name": "stable_diffusion_xl_vae",
|
||||
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
|
||||
},
|
||||
]
|
||||
|
||||
stable_diffusion_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
|
||||
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
|
||||
"model_name": "stable_diffusion_text_encoder",
|
||||
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
|
||||
"model_name": "stable_diffusion_vae",
|
||||
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "025a4b86a84829399d89f613e580757b",
|
||||
"model_name": "stable_diffusion_unet",
|
||||
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
|
||||
},
|
||||
]
|
||||
|
||||
joyai_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
|
||||
@@ -916,4 +963,4 @@ joyai_image_series = [
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + stable_diffusion_xl_series + stable_diffusion_series + joyai_image_series
|
||||
|
||||
@@ -295,6 +295,45 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
def QwenImageTextEncoder_Module_Map_Updater():
|
||||
|
||||
255
diffsynth/diffusion/ddim_scheduler.py
Normal file
255
diffsynth/diffusion/ddim_scheduler.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import torch, math
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class DDIMScheduler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.00085,
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
|
||||
clip_sample: bool = False,
|
||||
set_alpha_to_one: bool = False,
|
||||
steps_offset: int = 1,
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.beta_start = beta_start
|
||||
self.beta_end = beta_end
|
||||
self.beta_schedule = beta_schedule
|
||||
self.clip_sample = clip_sample
|
||||
self.set_alpha_to_one = set_alpha_to_one
|
||||
self.steps_offset = steps_offset
|
||||
self.prediction_type = prediction_type
|
||||
self.timestep_spacing = timestep_spacing
|
||||
|
||||
# Compute betas
|
||||
if beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# SD 1.5 specific: sqrt-linear interpolation
|
||||
self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
self.betas = self._betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unsupported beta_schedule: {beta_schedule}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = self._rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
# For the final step, there is no previous alphas_cumprod
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# Setable values (will be populated by set_timesteps)
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(self._default_timesteps().astype("int64"))
|
||||
self.training = False
|
||||
|
||||
@staticmethod
|
||||
def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
|
||||
"""Create beta schedule via cosine alpha_bar function."""
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def _rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""Rescale betas to have zero terminal SNR."""
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
alphas_bar = alphas_bar_sqrt ** 2
|
||||
alphas = torch.cat([alphas_bar[1:], alphas_bar[:1]])
|
||||
return 1 - alphas
|
||||
|
||||
def _default_timesteps(self):
|
||||
"""Default timesteps before set_timesteps is called."""
|
||||
import numpy as np
|
||||
return np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64)
|
||||
|
||||
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
|
||||
"""Compute the variance for the DDIM step."""
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
return variance
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = 100, denoising_strength: float = 1.0, training: bool = False, **kwargs):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain.
|
||||
Follows FlowMatchScheduler interface: (num_inference_steps, denoising_strength, training, **kwargs)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
if denoising_strength != 1.0:
|
||||
# For img2img: adjust effective steps
|
||||
num_inference_steps = int(num_inference_steps * denoising_strength)
|
||||
|
||||
# Compute step ratio
|
||||
if self.timestep_spacing == "leading":
|
||||
# leading: arange * step_ratio, reverse, then add offset
|
||||
step_ratio = self.num_train_timesteps // num_inference_steps
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
|
||||
timesteps = timesteps + self.steps_offset
|
||||
elif self.timestep_spacing == "trailing":
|
||||
# trailing: timesteps = arange(num_steps, 0, -1) * step_ratio - 1
|
||||
step_ratio = self.num_train_timesteps / num_inference_steps
|
||||
timesteps = (np.arange(num_inference_steps, 0, -1) * step_ratio - 1).round()[::-1]
|
||||
elif self.timestep_spacing == "linspace":
|
||||
# linspace: evenly spaced from num_train_timesteps - 1 to 0
|
||||
timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1]
|
||||
else:
|
||||
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
if training:
|
||||
self.set_training_weight()
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
def set_training_weight(self):
|
||||
"""Set timestep weights for training (similar to FlowMatchScheduler)."""
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||
if len(self.timesteps) != 1000:
|
||||
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def step(self, model_output, timestep, sample, to_final: bool = False, eta: float = 0.0, **kwargs):
|
||||
"""
|
||||
DDIM step function.
|
||||
Follows FlowMatchScheduler interface: step(model_output, timestep, sample, to_final=False)
|
||||
|
||||
For SD 1.5, prediction_type="epsilon" and eta=0.0 (deterministic DDIM).
|
||||
"""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
if timestep.dim() == 0:
|
||||
timestep = timestep.item()
|
||||
elif timestep.dim() == 1:
|
||||
timestep = timestep[0].item()
|
||||
|
||||
# Ensure timestep is int
|
||||
timestep = int(timestep)
|
||||
|
||||
# Find the index of the current timestep
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs()).item()
|
||||
|
||||
if timestep_id + 1 >= len(self.timesteps):
|
||||
prev_timestep = -1
|
||||
else:
|
||||
prev_timestep = self.timesteps[timestep_id + 1].item()
|
||||
|
||||
# Get alphas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
|
||||
alpha_prod_t = alpha_prod_t.to(device=sample.device, dtype=sample.dtype)
|
||||
alpha_prod_t_prev = alpha_prod_t_prev.to(device=sample.device, dtype=sample.dtype)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# Compute predicted original sample (x_0)
|
||||
if self.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
|
||||
elif self.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.prediction_type == "v_prediction":
|
||||
pred_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
|
||||
|
||||
# Clip sample if needed
|
||||
if self.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(-1.0, 1.0)
|
||||
|
||||
# Compute predicted noise (re-derived from x_0)
|
||||
pred_epsilon = (sample - alpha_prod_t.sqrt() * pred_original_sample) / beta_prod_t.sqrt()
|
||||
|
||||
# DDIM formula: prev_sample = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * epsilon
|
||||
prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + (1 - alpha_prod_t_prev).sqrt() * pred_epsilon
|
||||
|
||||
# Add variance noise if eta > 0 (DDIM: eta=0, DDPM: eta=1)
|
||||
if eta > 0:
|
||||
variance = self._get_variance(timestep, prev_timestep)
|
||||
variance = variance.to(device=sample.device, dtype=sample.dtype)
|
||||
std_dev_t = eta * variance.sqrt()
|
||||
device = sample.device
|
||||
noise = torch.randn_like(sample)
|
||||
prev_sample = prev_sample + std_dev_t * noise
|
||||
|
||||
return prev_sample
|
||||
|
||||
def add_noise(self, original_samples, noise, timestep):
|
||||
"""Add noise to original samples (forward diffusion).
|
||||
Follows FlowMatchScheduler interface: add_noise(original_samples, noise, timestep)
|
||||
"""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
if timestep.dim() == 0:
|
||||
timestep = timestep.item()
|
||||
elif timestep.dim() == 1:
|
||||
timestep = timestep[0].item()
|
||||
|
||||
timestep = int(timestep)
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
||||
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
|
||||
# Handle broadcasting for batch timesteps
|
||||
while sqrt_alpha_prod.dim() < original_samples.dim():
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sample = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return sample
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
"""Return the training target for the given prediction type."""
|
||||
if self.prediction_type == "epsilon":
|
||||
return noise
|
||||
elif self.prediction_type == "v_prediction":
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
||||
return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
elif self.prediction_type == "sample":
|
||||
return sample
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction_type: {self.prediction_type}")
|
||||
|
||||
def training_weight(self, timestep):
|
||||
"""Return training weight for the given timestep."""
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
return self.linear_timesteps_weights[timestep_id]
|
||||
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
78
diffsynth/models/stable_diffusion_text_encoder.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SDTextEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
vocab_size=49408,
|
||||
layer_norm_eps=1e-05,
|
||||
hidden_act="quick_gelu",
|
||||
initializer_factor=1.0,
|
||||
initializer_range=0.02,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
projection_dim=768,
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import CLIPConfig, CLIPTextModel
|
||||
|
||||
config = CLIPConfig(
|
||||
text_config={
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size": intermediate_size,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"num_attention_heads": num_attention_heads,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
"vocab_size": vocab_size,
|
||||
"layer_norm_eps": layer_norm_eps,
|
||||
"hidden_act": hidden_act,
|
||||
"initializer_factor": initializer_factor,
|
||||
"initializer_range": initializer_range,
|
||||
"bos_token_id": bos_token_id,
|
||||
"eos_token_id": eos_token_id,
|
||||
"pad_token_id": pad_token_id,
|
||||
"projection_dim": projection_dim,
|
||||
"dropout": 0.0,
|
||||
},
|
||||
vision_config={
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size": intermediate_size,
|
||||
"num_hidden_layers": num_hidden_layers,
|
||||
"num_attention_heads": num_attention_heads,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
"layer_norm_eps": layer_norm_eps,
|
||||
"hidden_act": hidden_act,
|
||||
"initializer_factor": initializer_factor,
|
||||
"initializer_range": initializer_range,
|
||||
"projection_dim": projection_dim,
|
||||
},
|
||||
projection_dim=projection_dim,
|
||||
)
|
||||
self.model = CLIPTextModel(config.text_config)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
):
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
if output_hidden_states:
|
||||
return outputs.last_hidden_state, outputs.hidden_states
|
||||
return outputs.last_hidden_state
|
||||
912
diffsynth/models/stable_diffusion_unet.py
Normal file
912
diffsynth/models/stable_diffusion_unet.py
Normal file
@@ -0,0 +1,912 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ===== Time Embedding =====
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.freq_shift = freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / half_dim + self.freq_shift
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
if self.flip_sin_to_cos:
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
else:
|
||||
emb = torch.cat([sin_emb, cos_emb], dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||
out_dim = out_dim if out_dim is not None else time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
# ===== ResNet Blocks =====
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "gelu":
|
||||
self.nonlinearity = nn.GELU()
|
||||
elif non_linearity == "relu":
|
||||
self.nonlinearity = nn.ReLU()
|
||||
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.conv_shortcut = None
|
||||
if conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||
|
||||
def forward(self, input_tensor, temb=None):
|
||||
hidden_states = input_tensor
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
return output_tensor
|
||||
|
||||
|
||||
# ===== Transformer Blocks =====
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, dropout=0.0):
|
||||
super().__init__()
|
||||
self.net = nn.ModuleList([
|
||||
GEGLU(dim, dim * 4),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Attention block matching diffusers checkpoint key format.
|
||||
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
bias=False,
|
||||
upcast_attention=False,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.inner_dim = inner_dim
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_out = nn.ModuleList([
|
||||
nn.Linear(inner_dim, query_dim, bias=True),
|
||||
nn.Dropout(dropout),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
# Query
|
||||
query = self.to_q(hidden_states)
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
# Key/Value
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
head_dim = self.inner_dim // self.heads
|
||||
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Scaled dot-product attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Output projection
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
cross_attention_dim=None,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
bias=False,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
bias=False,
|
||||
upcast_attention=upcast_attention,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.ff = FeedForward(dim, dropout=dropout)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
# Self-attention
|
||||
attn_output = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = attn_output + hidden_states
|
||||
# Cross-attention
|
||||
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = attn_output + hidden_states
|
||||
# Feed-forward
|
||||
ff_output = self.ff(self.norm3(hidden_states))
|
||||
hidden_states = ff_output + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Transformer2DModel(nn.Module):
|
||||
"""2D Transformer block wrapper matching diffusers checkpoint structure.
|
||||
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads=16,
|
||||
attention_head_dim=64,
|
||||
in_channels=320,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
norm_num_groups=32,
|
||||
cross_attention_dim=768,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
dim=num_attention_heads * attention_head_dim,
|
||||
n_heads=num_attention_heads,
|
||||
d_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
# Normalize and project to sequence
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||
|
||||
# Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# Project back to 2D
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== Down/Up Blocks =====
|
||||
|
||||
class CrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=out_channels // attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=1)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = []
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
return hidden_states, tuple(output_states)
|
||||
|
||||
|
||||
class DownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = False
|
||||
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=1)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = []
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
return hidden_states, tuple(output_states)
|
||||
|
||||
|
||||
class CrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=out_channels // attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# Pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = False
|
||||
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== UNet Mid Block =====
|
||||
|
||||
class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# There is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=in_channels // attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== Downsample / Upsample =====
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
|
||||
self.padding = padding
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.padding == 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states, upsample_size=None):
|
||||
if upsample_size is not None:
|
||||
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
# ===== UNet2DConditionModel =====
|
||||
|
||||
class UNet2DConditionModel(nn.Module):
|
||||
"""Stable Diffusion UNet with cross-attention conditioning.
|
||||
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=64,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels=(320, 640, 1280, 1280),
|
||||
layers_per_block=2,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=8,
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
dropout=0.0,
|
||||
act_fn="silu",
|
||||
time_embedding_type="positional",
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
time_embedding_dim=None,
|
||||
resnet_time_scale_shift="default",
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.sample_size = sample_size
|
||||
|
||||
# Time embedding
|
||||
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
|
||||
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
|
||||
|
||||
# Input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
|
||||
# Down blocks
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
if "CrossAttn" in down_block_type:
|
||||
down_block = CrossAttnDownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
downsample=not is_final_block,
|
||||
)
|
||||
else:
|
||||
down_block = DownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
downsample=not is_final_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# Mid block
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
|
||||
# Up blocks
|
||||
self.up_blocks = nn.ModuleList()
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
if "CrossAttn" in up_block_type:
|
||||
up_block = CrossAttnUpBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
upsample=not is_final_block,
|
||||
)
|
||||
else:
|
||||
up_block = UpBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Output
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
|
||||
# 1. Time embedding
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. Pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. Down
|
||||
down_block_res_samples = (sample,)
|
||||
for down_block in self.down_blocks:
|
||||
sample, res_samples = down_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. Mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# 5. Up
|
||||
for up_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(up_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
|
||||
|
||||
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
|
||||
sample = up_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
|
||||
# 6. Post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
return sample
|
||||
642
diffsynth/models/stable_diffusion_vae.py
Normal file
642
diffsynth/models/stable_diffusion_vae.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
||||
)
|
||||
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||
# randn_like doesn't accept generator on all torch versions
|
||||
sample = torch.randn(self.mean.shape, generator=generator,
|
||||
device=self.parameters.device, dtype=self.parameters.dtype)
|
||||
return self.mean + self.std * sample
|
||||
|
||||
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.tensor([0.0])
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "gelu":
|
||||
self.nonlinearity = nn.GELU()
|
||||
elif non_linearity == "relu":
|
||||
self.nonlinearity = nn.ReLU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
|
||||
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.conv_shortcut = None
|
||||
if conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||
|
||||
def forward(self, input_tensor, temb=None):
|
||||
hidden_states = input_tensor
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
return output_tensor
|
||||
|
||||
|
||||
class DownEncoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=downsample_padding)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, *args, **kwargs):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
temb_channels=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetMidBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
temb_channels=None,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
add_attention=True,
|
||||
attention_head_dim=1,
|
||||
output_scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_groups=resnet_groups,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Simple attention block for VAE mid block.
|
||||
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
|
||||
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
|
||||
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
|
||||
"""
|
||||
def __init__(self, channels, num_groups=32, eps=1e-6):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
self.heads = 1
|
||||
self.rescale_output_factor = 1.0
|
||||
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(channels, channels, bias=True)
|
||||
self.to_k = nn.Linear(channels, channels, bias=True)
|
||||
self.to_v = nn.Linear(channels, channels, bias=True)
|
||||
self.to_out = nn.ModuleList([
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
nn.Dropout(0.0),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
||||
# Group norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# QKV projection
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
|
||||
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // self.heads
|
||||
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Scaled dot-product attention
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Output projection + dropout
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
# Reshape back to 4D and add residual
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# Rescale output factor
|
||||
hidden_states = hidden_states / self.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
|
||||
Key names: conv.weight/bias.
|
||||
When padding=0, applies explicit F.pad before conv to match dimension.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, padding=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.padding = padding
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.padding == 0:
|
||||
import torch.nn.functional as F
|
||||
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
double_z=True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
down_block = DownEncoderBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_downsample=not is_final_block,
|
||||
downsample_padding=0,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.conv_in(sample)
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample)
|
||||
sample = self.mid_block(sample)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
norm_type="group",
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
temb_channels = in_channels if norm_type == "spatial" else None
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=temb_channels,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
up_block = UpDecoderBlock2D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_upsample=not is_final_block,
|
||||
temb_channels=temb_channels,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, sample, latent_embeds=None):
|
||||
sample = self.conv_in(sample)
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
for up_block in self.up_blocks:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class StableDiffusionVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
|
||||
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
|
||||
block_out_channels=(128, 256, 512, 512),
|
||||
layers_per_block=2,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=32,
|
||||
sample_size=512,
|
||||
scaling_factor=0.18215,
|
||||
shift_factor=None,
|
||||
latents_mean=None,
|
||||
latents_std=None,
|
||||
force_upcast=True,
|
||||
use_quant_conv=True,
|
||||
use_post_quant_conv=True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
double_z=True,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
|
||||
self.latents_mean = latents_mean
|
||||
self.latents_std = latents_std
|
||||
self.scaling_factor = scaling_factor
|
||||
self.shift_factor = shift_factor
|
||||
self.sample_size = sample_size
|
||||
self.force_upcast = force_upcast
|
||||
|
||||
def _encode(self, x):
|
||||
h = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def encode(self, x):
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
return posterior
|
||||
|
||||
def _decode(self, z):
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
return self.decoder(z)
|
||||
|
||||
def decode(self, z):
|
||||
return self._decode(z)
|
||||
|
||||
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
|
||||
posterior = self.encode(sample)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
# Scale latent
|
||||
z = z * self.scaling_factor
|
||||
decode = self.decode(z)
|
||||
if return_dict:
|
||||
return {"sample": decode, "posterior": posterior, "latent_sample": z}
|
||||
return decode, posterior
|
||||
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
62
diffsynth/models/stable_diffusion_xl_text_encoder.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
|
||||
|
||||
class SDXLTextEncoder2(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
vocab_size=49408,
|
||||
layer_norm_eps=1e-05,
|
||||
hidden_act="gelu",
|
||||
initializer_factor=1.0,
|
||||
initializer_range=0.02,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
projection_dim=1280,
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
|
||||
|
||||
config = CLIPTextConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
vocab_size=vocab_size,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_act=hidden_act,
|
||||
initializer_factor=initializer_factor,
|
||||
initializer_range=initializer_range,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
projection_dim=projection_dim,
|
||||
)
|
||||
self.model = CLIPTextModelWithProjection(config)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
):
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
if output_hidden_states:
|
||||
return outputs.text_embeds, outputs.hidden_states
|
||||
return outputs.text_embeds
|
||||
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
922
diffsynth/models/stable_diffusion_xl_unet.py
Normal file
@@ -0,0 +1,922 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ===== Time Embedding =====
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.freq_shift = freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
half_dim = self.num_channels // 2
|
||||
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
|
||||
exponent = exponent / half_dim + self.freq_shift
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
sin_emb = torch.sin(emb)
|
||||
cos_emb = torch.cos(emb)
|
||||
if self.flip_sin_to_cos:
|
||||
emb = torch.cat([cos_emb, sin_emb], dim=-1)
|
||||
else:
|
||||
emb = torch.cat([sin_emb, cos_emb], dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||
out_dim = out_dim if out_dim is not None else time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
# ===== ResNet Blocks =====
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "gelu":
|
||||
self.nonlinearity = nn.GELU()
|
||||
elif non_linearity == "relu":
|
||||
self.nonlinearity = nn.ReLU()
|
||||
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.conv_shortcut = None
|
||||
if conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||
|
||||
def forward(self, input_tensor, temb=None):
|
||||
hidden_states = input_tensor
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
return output_tensor
|
||||
|
||||
|
||||
# ===== Transformer Blocks =====
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, dropout=0.0):
|
||||
super().__init__()
|
||||
self.net = nn.ModuleList([
|
||||
GEGLU(dim, dim * 4),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
bias=False,
|
||||
upcast_attention=False,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.inner_dim = inner_dim
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_out = nn.ModuleList([
|
||||
nn.Linear(inner_dim, query_dim, bias=True),
|
||||
nn.Dropout(dropout),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
query = self.to_q(hidden_states)
|
||||
batch_size, seq_len, _ = query.shape
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
head_dim = self.inner_dim // self.heads
|
||||
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
cross_attention_dim=None,
|
||||
upcast_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
bias=False,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
bias=False,
|
||||
upcast_attention=upcast_attention,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.ff = FeedForward(dim, dropout=dropout)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
attn_output = self.attn1(self.norm1(hidden_states))
|
||||
hidden_states = attn_output + hidden_states
|
||||
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = attn_output + hidden_states
|
||||
ff_output = self.ff(self.norm3(hidden_states))
|
||||
hidden_states = ff_output + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Transformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads=16,
|
||||
attention_head_dim=64,
|
||||
in_channels=320,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
norm_num_groups=32,
|
||||
cross_attention_dim=768,
|
||||
upcast_attention=False,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.use_linear_projection = use_linear_projection
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
|
||||
|
||||
if use_linear_projection:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
|
||||
else:
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
dim=inner_dim,
|
||||
n_heads=num_attention_heads,
|
||||
d_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
|
||||
else:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if self.use_linear_projection:
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
if self.use_linear_projection:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||
else:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== Down/Up Blocks =====
|
||||
|
||||
class CrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
downsample=True,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=out_channels // attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=1)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = []
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
return hidden_states, tuple(output_states)
|
||||
|
||||
|
||||
class DownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = False
|
||||
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=1)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = []
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
output_states.append(hidden_states)
|
||||
|
||||
return hidden_states, tuple(output_states)
|
||||
|
||||
|
||||
class CrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
upsample=True,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=out_channels // attention_head_dim,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.has_cross_attention = False
|
||||
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== UNet Mid Block =====
|
||||
|
||||
class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
temb_channels=1280,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
cross_attention_dim=768,
|
||||
attention_head_dim=1,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=attention_head_dim,
|
||||
attention_head_dim=in_channels // attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=transformer_layers_per_block,
|
||||
dropout=dropout,
|
||||
norm_num_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=1.0,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# ===== Downsample / Upsample =====
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
|
||||
self.padding = padding
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.padding == 0:
|
||||
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states, upsample_size=None):
|
||||
if upsample_size is not None:
|
||||
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
# ===== SDXL UNet2DConditionModel =====
|
||||
|
||||
class SDXLUNet2DConditionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=128,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels=(320, 640, 1280),
|
||||
layers_per_block=2,
|
||||
cross_attention_dim=2048,
|
||||
attention_head_dim=5,
|
||||
transformer_layers_per_block=1,
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
dropout=0.0,
|
||||
act_fn="silu",
|
||||
time_embedding_type="positional",
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
time_embedding_dim=None,
|
||||
resnet_time_scale_shift="default",
|
||||
upcast_attention=False,
|
||||
use_linear_projection=False,
|
||||
addition_embed_type=None,
|
||||
addition_time_embed_dim=None,
|
||||
projection_class_embeddings_input_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.sample_size = sample_size
|
||||
self.addition_embed_type = addition_embed_type
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
|
||||
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
|
||||
|
||||
if addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
|
||||
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
if "CrossAttn" in down_block_type:
|
||||
down_block = CrossAttnDownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim[i],
|
||||
downsample=not is_final_block,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
else:
|
||||
down_block = DownBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
downsample=not is_final_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=1,
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=attention_head_dim[-1],
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList()
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
if "CrossAttn" in up_block_type:
|
||||
up_block = CrossAttnUpBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_head_dim=reversed_attention_head_dim[i],
|
||||
upsample=not is_final_block,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
else:
|
||||
up_block = UpBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.addition_embed_type == "text_time":
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
emb = emb + aug_emb
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for down_block in self.down_blocks:
|
||||
sample, res_samples = down_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(up_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
|
||||
|
||||
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
|
||||
sample = up_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
return sample
|
||||
222
diffsynth/pipelines/stable_diffusion.py
Normal file
222
diffsynth/pipelines/stable_diffusion.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Union
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion.ddim_scheduler import DDIMScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from transformers import AutoTokenizer, CLIPTextModel
|
||||
from ..models.stable_diffusion_text_encoder import SDTextEncoder
|
||||
from ..models.stable_diffusion_unet import UNet2DConditionModel
|
||||
from ..models.stable_diffusion_vae import StableDiffusionVAE
|
||||
|
||||
|
||||
class StableDiffusionPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=8, width_division_factor=8,
|
||||
)
|
||||
self.scheduler = DDIMScheduler()
|
||||
self.text_encoder: SDTextEncoder = None
|
||||
self.unet: UNet2DConditionModel = None
|
||||
self.vae: StableDiffusionVAE = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
|
||||
self.in_iteration_models = ("unet",)
|
||||
self.units = [
|
||||
SDUnit_ShapeChecker(),
|
||||
SDUnit_PromptEmbedder(),
|
||||
SDUnit_NoiseInitializer(),
|
||||
SDUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_stable_diffusion
|
||||
self.compilable_models = ["unet"]
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
|
||||
# Override vram_config to use the specified torch_dtype for all models
|
||||
for mc in model_configs:
|
||||
mc._vram_config_override = {
|
||||
'onload_dtype': torch_dtype,
|
||||
'computation_dtype': torch_dtype,
|
||||
}
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
|
||||
pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
|
||||
pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 50,
|
||||
eta: float = 0.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# 1. Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps, eta=eta,
|
||||
)
|
||||
|
||||
# 2. Three-dict input preparation
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"guidance_rescale": guidance_rescale,
|
||||
}
|
||||
|
||||
# 3. Unit chain execution
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# 4. Denoise loop
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
|
||||
)
|
||||
|
||||
# 5. VAE decode
|
||||
self.load_models_to_device(['vae'])
|
||||
latents = inputs_shared["latents"] / self.vae.scaling_factor
|
||||
image = self.vae.decode(latents)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class SDUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width"),
|
||||
output_params=("height", "width"),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionPipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
class SDUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
pipe: StableDiffusionPipeline,
|
||||
prompt: str,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
text_inputs = pipe.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=pipe.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_embeds = pipe.text_encoder(text_input_ids)
|
||||
# TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state.
|
||||
# last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt.
|
||||
if isinstance(prompt_embeds, tuple):
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
return prompt_embeds
|
||||
|
||||
def process(self, pipe: StableDiffusionPipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class SDUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise(
|
||||
(1, pipe.unet.in_channels, height // 8, width // 8),
|
||||
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
|
||||
)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class SDUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise",),
|
||||
output_params=("latents",),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionPipeline, noise):
|
||||
# For Text-to-Image, latents = noise (scaled by scheduler)
|
||||
latents = noise * pipe.scheduler.init_noise_sigma
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
def model_fn_stable_diffusion(
|
||||
unet: UNet2DConditionModel,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
cross_attention_kwargs=None,
|
||||
timestep_cond=None,
|
||||
added_cond_kwargs=None,
|
||||
**kwargs,
|
||||
):
|
||||
# SD timestep is already in 0-999 range, no scaling needed
|
||||
noise_pred = unet(
|
||||
latents,
|
||||
timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
timestep_cond=timestep_cond,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
return noise_pred
|
||||
@@ -0,0 +1,7 @@
|
||||
def SDTextEncoderStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("text_model.") and "position_ids" not in key:
|
||||
new_key = "model." + key
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
return new_state_dict
|
||||
@@ -0,0 +1,18 @@
|
||||
def SDVAEStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if ".query." in key:
|
||||
new_key = key.replace(".query.", ".to_q.")
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
elif ".key." in key:
|
||||
new_key = key.replace(".key.", ".to_k.")
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
elif ".value." in key:
|
||||
new_key = key.replace(".value.", ".to_v.")
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
elif ".proj_attn." in key:
|
||||
new_key = key.replace(".proj_attn.", ".to_out.0.")
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
else:
|
||||
new_state_dict[key] = state_dict[key]
|
||||
return new_state_dict
|
||||
@@ -0,0 +1,7 @@
|
||||
def SDXLTextEncoder2StateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("text_model.") and "position_ids" not in key:
|
||||
new_key = "model." + key
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
return new_state_dict
|
||||
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth.core import ModelConfig
|
||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="a photo of an astronaut riding a horse on mars",
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
height=512,
|
||||
width=512,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
)
|
||||
image.save("output_stable_diffusion_t2i.png")
|
||||
print("Image saved to output_stable_diffusion_t2i.png")
|
||||
@@ -0,0 +1,37 @@
|
||||
import torch
|
||||
from diffsynth.core import ModelConfig
|
||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float32,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float32,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float32,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.float32,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="a photo of an astronaut riding a horse on mars",
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
height=512,
|
||||
width=512,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
)
|
||||
image.save("output_stable_diffusion_t2i_low_vram.png")
|
||||
print("Image saved to output_stable_diffusion_t2i_low_vram.png")
|
||||
Reference in New Issue
Block a user