From 82e482286c7f59675e030095c99852a4bb8e9754 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Thu, 23 Apr 2026 17:35:24 +0800 Subject: [PATCH] sd --- diffsynth/configs/model_configs.py | 49 +- .../configs/vram_management_module_maps.py | 39 + diffsynth/diffusion/ddim_scheduler.py | 255 +++++ .../models/stable_diffusion_text_encoder.py | 78 ++ diffsynth/models/stable_diffusion_unet.py | 912 +++++++++++++++++ diffsynth/models/stable_diffusion_vae.py | 642 ++++++++++++ .../stable_diffusion_xl_text_encoder.py | 62 ++ diffsynth/models/stable_diffusion_xl_unet.py | 922 ++++++++++++++++++ diffsynth/pipelines/stable_diffusion.py | 222 +++++ .../stable_diffusion_text_encoder.py | 7 + .../stable_diffusion_vae.py | 18 + .../stable_diffusion_xl_text_encoder.py | 7 + .../model_inference/StableDiffusion-T2I.py | 25 + .../StableDiffusion-T2I.py | 37 + 14 files changed, 3274 insertions(+), 1 deletion(-) create mode 100644 diffsynth/diffusion/ddim_scheduler.py create mode 100644 diffsynth/models/stable_diffusion_text_encoder.py create mode 100644 diffsynth/models/stable_diffusion_unet.py create mode 100644 diffsynth/models/stable_diffusion_vae.py create mode 100644 diffsynth/models/stable_diffusion_xl_text_encoder.py create mode 100644 diffsynth/models/stable_diffusion_xl_unet.py create mode 100644 diffsynth/pipelines/stable_diffusion.py create mode 100644 diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py create mode 100644 diffsynth/utils/state_dict_converters/stable_diffusion_vae.py create mode 100644 diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py create mode 100644 examples/stable_diffusion/model_inference/StableDiffusion-T2I.py create mode 100644 examples/stable_diffusion/model_inference_low_vram/StableDiffusion-T2I.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 5fc95c3..f017aaa 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -900,6 +900,53 @@ mova_series = [ "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", }, ] +stable_diffusion_xl_series = [ + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors") + "model_hash": "142b114f67f5ab3a6d83fb5788f12ded", + "model_name": "stable_diffusion_xl_unet", + "model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel", + }, + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors") + "model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a", + "model_name": "stable_diffusion_xl_text_encoder", + "model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2", + "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter", + }, + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "13115dd45a6e1c39860f91ab073b8a78", + "model_name": "stable_diffusion_xl_vae", + "model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter", + "extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True}, + }, +] + +stable_diffusion_series = [ + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors") + "model_hash": "ffd1737ae9df7fd43f5fbed653bdad67", + "model_name": "stable_diffusion_text_encoder", + "model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder", + "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter", + }, + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors") + "model_hash": "f86d5683ed32433be8ca69969c67ba69", + "model_name": "stable_diffusion_vae", + "model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE", + "state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter", + }, + { + # Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors") + "model_hash": "025a4b86a84829399d89f613e580757b", + "model_name": "stable_diffusion_unet", + "model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel", + }, +] + joyai_image_series = [ { # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth") @@ -916,4 +963,4 @@ joyai_image_series = [ }, ] -MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series +MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + stable_diffusion_xl_series + stable_diffusion_series + joyai_image_series diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index 8d4800b..1970135 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -295,6 +295,45 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", }, + "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } def QwenImageTextEncoder_Module_Map_Updater(): diff --git a/diffsynth/diffusion/ddim_scheduler.py b/diffsynth/diffusion/ddim_scheduler.py new file mode 100644 index 0000000..f131f4a --- /dev/null +++ b/diffsynth/diffusion/ddim_scheduler.py @@ -0,0 +1,255 @@ +import torch, math +from typing import Literal + + +class DDIMScheduler: + + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", + clip_sample: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 1, + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", + timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading", + rescale_betas_zero_snr: bool = False, + ): + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.beta_schedule = beta_schedule + self.clip_sample = clip_sample + self.set_alpha_to_one = set_alpha_to_one + self.steps_offset = steps_offset + self.prediction_type = prediction_type + self.timestep_spacing = timestep_spacing + + # Compute betas + if beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # SD 1.5 specific: sqrt-linear interpolation + self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + self.betas = self._betas_for_alpha_bar(num_train_timesteps) + else: + raise ValueError(f"Unsupported beta_schedule: {beta_schedule}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = self._rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # For the final step, there is no previous alphas_cumprod + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # Setable values (will be populated by set_timesteps) + self.num_inference_steps = None + self.timesteps = torch.from_numpy(self._default_timesteps().astype("int64")) + self.training = False + + @staticmethod + def _betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor: + """Create beta schedule via cosine alpha_bar function.""" + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + @staticmethod + def _rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: + """Rescale betas to have zero terminal SNR.""" + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + alphas_bar_sqrt -= alphas_bar_sqrt_T + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + alphas_bar = alphas_bar_sqrt ** 2 + alphas = torch.cat([alphas_bar[1:], alphas_bar[:1]]) + return 1 - alphas + + def _default_timesteps(self): + """Default timesteps before set_timesteps is called.""" + import numpy as np + return np.arange(0, self.num_train_timesteps)[::-1].copy().astype(np.int64) + + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: + """Compute the variance for the DDIM step.""" + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + return variance + + def set_timesteps(self, num_inference_steps: int = 100, denoising_strength: float = 1.0, training: bool = False, **kwargs): + """ + Sets the discrete timesteps used for the diffusion chain. + Follows FlowMatchScheduler interface: (num_inference_steps, denoising_strength, training, **kwargs) + """ + import numpy as np + + if denoising_strength != 1.0: + # For img2img: adjust effective steps + num_inference_steps = int(num_inference_steps * denoising_strength) + + # Compute step ratio + if self.timestep_spacing == "leading": + # leading: arange * step_ratio, reverse, then add offset + step_ratio = self.num_train_timesteps // num_inference_steps + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) + timesteps = timesteps + self.steps_offset + elif self.timestep_spacing == "trailing": + # trailing: timesteps = arange(num_steps, 0, -1) * step_ratio - 1 + step_ratio = self.num_train_timesteps / num_inference_steps + timesteps = (np.arange(num_inference_steps, 0, -1) * step_ratio - 1).round()[::-1] + elif self.timestep_spacing == "linspace": + # linspace: evenly spaced from num_train_timesteps - 1 to 0 + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1] + else: + raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}") + + self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64) + self.num_inference_steps = num_inference_steps + + if training: + self.set_training_weight() + self.training = True + else: + self.training = False + + def set_training_weight(self): + """Set timestep weights for training (similar to FlowMatchScheduler).""" + steps = 1000 + x = self.timesteps + y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (steps / y_shifted.sum()) + if len(self.timesteps) != 1000: + bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps) + bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1] + self.linear_timesteps_weights = bsmntw_weighing + + def step(self, model_output, timestep, sample, to_final: bool = False, eta: float = 0.0, **kwargs): + """ + DDIM step function. + Follows FlowMatchScheduler interface: step(model_output, timestep, sample, to_final=False) + + For SD 1.5, prediction_type="epsilon" and eta=0.0 (deterministic DDIM). + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + if timestep.dim() == 0: + timestep = timestep.item() + elif timestep.dim() == 1: + timestep = timestep[0].item() + + # Ensure timestep is int + timestep = int(timestep) + + # Find the index of the current timestep + timestep_id = torch.argmin((self.timesteps - timestep).abs()).item() + + if timestep_id + 1 >= len(self.timesteps): + prev_timestep = -1 + else: + prev_timestep = self.timesteps[timestep_id + 1].item() + + # Get alphas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + alpha_prod_t = alpha_prod_t.to(device=sample.device, dtype=sample.dtype) + alpha_prod_t_prev = alpha_prod_t_prev.to(device=sample.device, dtype=sample.dtype) + + beta_prod_t = 1 - alpha_prod_t + + # Compute predicted original sample (x_0) + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + else: + raise ValueError(f"Unsupported prediction_type: {self.prediction_type}") + + # Clip sample if needed + if self.clip_sample: + pred_original_sample = pred_original_sample.clamp(-1.0, 1.0) + + # Compute predicted noise (re-derived from x_0) + pred_epsilon = (sample - alpha_prod_t.sqrt() * pred_original_sample) / beta_prod_t.sqrt() + + # DDIM formula: prev_sample = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * epsilon + prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + (1 - alpha_prod_t_prev).sqrt() * pred_epsilon + + # Add variance noise if eta > 0 (DDIM: eta=0, DDPM: eta=1) + if eta > 0: + variance = self._get_variance(timestep, prev_timestep) + variance = variance.to(device=sample.device, dtype=sample.dtype) + std_dev_t = eta * variance.sqrt() + device = sample.device + noise = torch.randn_like(sample) + prev_sample = prev_sample + std_dev_t * noise + + return prev_sample + + def add_noise(self, original_samples, noise, timestep): + """Add noise to original samples (forward diffusion). + Follows FlowMatchScheduler interface: add_noise(original_samples, noise, timestep) + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + if timestep.dim() == 0: + timestep = timestep.item() + elif timestep.dim() == 1: + timestep = timestep[0].item() + + timestep = int(timestep) + sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt() + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt() + + sqrt_alpha_prod = sqrt_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.to(device=original_samples.device, dtype=original_samples.dtype) + + # Handle broadcasting for batch timesteps + while sqrt_alpha_prod.dim() < original_samples.dim(): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + sample = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return sample + + def training_target(self, sample, noise, timestep): + """Return the training target for the given prediction type.""" + if self.prediction_type == "epsilon": + return noise + elif self.prediction_type == "v_prediction": + sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt() + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt() + return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + elif self.prediction_type == "sample": + return sample + else: + raise ValueError(f"Unsupported prediction_type: {self.prediction_type}") + + def training_weight(self, timestep): + """Return training weight for the given timestep.""" + timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs()) + return self.linear_timesteps_weights[timestep_id] diff --git a/diffsynth/models/stable_diffusion_text_encoder.py b/diffsynth/models/stable_diffusion_text_encoder.py new file mode 100644 index 0000000..048e246 --- /dev/null +++ b/diffsynth/models/stable_diffusion_text_encoder.py @@ -0,0 +1,78 @@ +import torch + + +class SDTextEncoder(torch.nn.Module): + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + vocab_size=49408, + layer_norm_eps=1e-05, + hidden_act="quick_gelu", + initializer_factor=1.0, + initializer_range=0.02, + bos_token_id=0, + eos_token_id=2, + pad_token_id=1, + projection_dim=768, + ): + super().__init__() + from transformers import CLIPConfig, CLIPTextModel + + config = CLIPConfig( + text_config={ + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "num_hidden_layers": num_hidden_layers, + "num_attention_heads": num_attention_heads, + "max_position_embeddings": max_position_embeddings, + "vocab_size": vocab_size, + "layer_norm_eps": layer_norm_eps, + "hidden_act": hidden_act, + "initializer_factor": initializer_factor, + "initializer_range": initializer_range, + "bos_token_id": bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "projection_dim": projection_dim, + "dropout": 0.0, + }, + vision_config={ + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "num_hidden_layers": num_hidden_layers, + "num_attention_heads": num_attention_heads, + "max_position_embeddings": max_position_embeddings, + "layer_norm_eps": layer_norm_eps, + "hidden_act": hidden_act, + "initializer_factor": initializer_factor, + "initializer_range": initializer_range, + "projection_dim": projection_dim, + }, + projection_dim=projection_dim, + ) + self.model = CLIPTextModel(config.text_config) + self.config = config + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_hidden_states=True, + **kwargs, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=output_hidden_states, + return_dict=True, + **kwargs, + ) + if output_hidden_states: + return outputs.last_hidden_state, outputs.hidden_states + return outputs.last_hidden_state diff --git a/diffsynth/models/stable_diffusion_unet.py b/diffsynth/models/stable_diffusion_unet.py new file mode 100644 index 0000000..fc6122d --- /dev/null +++ b/diffsynth/models/stable_diffusion_unet.py @@ -0,0 +1,912 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Optional + + +# ===== Time Embedding ===== + +class Timesteps(nn.Module): + def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.freq_shift = freq_shift + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / half_dim + self.freq_shift + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + if self.flip_sin_to_cos: + emb = torch.cat([cos_emb, sin_emb], dim=-1) + else: + emb = torch.cat([sin_emb, cos_emb], dim=-1) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = nn.SiLU() if act_fn == "silu" else nn.GELU() + out_dim = out_dim if out_dim is not None else time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, out_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +# ===== ResNet Blocks ===== + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps) + self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels)) + + self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = nn.SiLU() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + elif non_linearity == "relu": + self.nonlinearity = nn.ReLU() + + self.use_conv_shortcut = conv_shortcut + self.conv_shortcut = None + if conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1) + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +# ===== Transformer Blocks ===== + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, dropout=0.0): + super().__init__() + self.net = nn.ModuleList([ + GEGLU(dim, dim * 4), + nn.Dropout(dropout), + nn.Linear(dim * 4, dim if dim_out is None else dim_out), + ]) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class Attention(nn.Module): + """Attention block matching diffusers checkpoint key format. + Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias + """ + def __init__( + self, + query_dim, + heads=8, + dim_head=64, + dropout=0.0, + bias=False, + upcast_attention=False, + cross_attention_dim=None, + ): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.inner_dim = inner_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias) + self.to_out = nn.ModuleList([ + nn.Linear(inner_dim, query_dim, bias=True), + nn.Dropout(dropout), + ]) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + # Query + query = self.to_q(hidden_states) + batch_size, seq_len, _ = query.shape + + # Key/Value + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + head_dim = self.inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Scaled dot-product attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + # Reshape back + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim) + hidden_states = hidden_states.to(query.dtype) + + # Output projection + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + cross_attention_dim=None, + upcast_attention=False, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn1 = Attention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + ) + self.norm2 = nn.LayerNorm(dim) + self.attn2 = Attention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + cross_attention_dim=cross_attention_dim, + ) + self.norm3 = nn.LayerNorm(dim) + self.ff = FeedForward(dim, dropout=dropout) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + # Self-attention + attn_output = self.attn1(self.norm1(hidden_states)) + hidden_states = attn_output + hidden_states + # Cross-attention + attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + # Feed-forward + ff_output = self.ff(self.norm3(hidden_states)) + hidden_states = ff_output + hidden_states + return hidden_states + + +class Transformer2DModel(nn.Module): + """2D Transformer block wrapper matching diffusers checkpoint structure. + Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias + """ + def __init__( + self, + num_attention_heads=16, + attention_head_dim=64, + in_channels=320, + num_layers=1, + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + upcast_attention=False, + ): + super().__init__() + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + dim=num_attention_heads * attention_head_dim, + n_heads=num_attention_heads, + d_head=attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + for _ in range(num_layers) + ]) + + self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + # Normalize and project to sequence + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel) + + # Transformer blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + + # Project back to 2D + hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +# ===== Down/Up Blocks ===== + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + downsample=True, + ): + super().__init__() + self.has_cross_attention = True + + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D(out_channels, out_channels, padding=1) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + output_states.append(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states.append(hidden_states) + + return hidden_states, tuple(output_states) + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + downsample=True, + ): + super().__init__() + self.has_cross_attention = False + + resnets = [] + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D(out_channels, out_channels, padding=1) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = [] + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states.append(hidden_states) + + return hidden_states, tuple(output_states) + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + prev_output_channel, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + upsample=True, + ): + super().__init__() + self.has_cross_attention = True + + resnets = [] + attentions = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D(out_channels, out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + # Pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size=upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + prev_output_channel, + temb_channels=1280, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + upsample=True, + ): + super().__init__() + self.has_cross_attention = False + + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D(out_channels, out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size=upsample_size) + + return hidden_states + + +# ===== UNet Mid Block ===== + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # There is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=in_channels // attention_head_dim, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +# ===== Downsample / Upsample ===== + +class Downsample2D(nn.Module): + def __init__(self, in_channels, out_channels, padding=1): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding) + self.padding = padding + + def forward(self, hidden_states): + if self.padding == 0: + hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + return self.conv(hidden_states) + + +class Upsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states, upsample_size=None): + if upsample_size is not None: + hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + return self.conv(hidden_states) + + +# ===== UNet2DConditionModel ===== + +class UNet2DConditionModel(nn.Module): + """Stable Diffusion UNet with cross-attention conditioning. + state_dict keys match the diffusers UNet2DConditionModel checkpoint format. + """ + def __init__( + self, + sample_size=64, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(320, 640, 1280, 1280), + layers_per_block=2, + cross_attention_dim=768, + attention_head_dim=8, + norm_num_groups=32, + norm_eps=1e-5, + dropout=0.0, + act_fn="silu", + time_embedding_type="positional", + flip_sin_to_cos=True, + freq_shift=0, + time_embedding_dim=None, + resnet_time_scale_shift="default", + upcast_attention=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.sample_size = sample_size + + # Time embedding + timestep_embedding_dim = time_embedding_dim or block_out_channels[0] + self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift) + time_embed_dim = block_out_channels[0] * 4 + self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim) + + # Input + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) + + # Down blocks + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if "CrossAttn" in down_block_type: + down_block = CrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block, + transformer_layers_per_block=1, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + downsample=not is_final_block, + ) + else: + down_block = DownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample=not is_final_block, + ) + self.down_blocks.append(down_block) + + # Mid block + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + # Up blocks + self.up_blocks = nn.ModuleList() + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + # in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + if "CrossAttn" in up_block_type: + up_block = CrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block + 1, + transformer_layers_per_block=1, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + upsample=not is_final_block, + ) + else: + up_block = UpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + + # Output + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True): + # 1. Time embedding + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t_emb = self.time_proj(timesteps) + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb) + + # 2. Pre-process + sample = self.conv_in(sample) + + # 3. Down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + sample, res_samples = down_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + down_block_res_samples += res_samples + + # 4. Mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. Up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-len(up_block.resnets):] + down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)] + + upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None + sample = up_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. Post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return sample diff --git a/diffsynth/models/stable_diffusion_vae.py b/diffsynth/models/stable_diffusion_vae.py new file mode 100644 index 0000000..291cbeb --- /dev/null +++ b/diffsynth/models/stable_diffusion_vae.py @@ -0,0 +1,642 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +from typing import Optional + + +class DiagonalGaussianDistribution: + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # randn_like doesn't accept generator on all torch versions + sample = torch.randn(self.mean.shape, generator=generator, + device=self.parameters.device, dtype=self.parameters.dtype) + return self.mean + self.std * sample + + def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor: + if self.deterministic: + return torch.tensor([0.0]) + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3], + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps) + self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels)) + + self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = nn.SiLU() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + elif non_linearity == "relu": + self.nonlinearity = nn.ReLU() + else: + raise ValueError(f"Unsupported non_linearity: {non_linearity}") + + self.use_conv_shortcut = conv_shortcut + self.conv_shortcut = None + if conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1) + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D(out_channels, out_channels, padding=downsample_padding) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, *args, **kwargs): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D(out_channels, out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class UNetMidBlock2D(nn.Module): + def __init__( + self, + in_channels, + temb_channels=None, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + add_attention=True, + attention_head_dim=1, + output_scale_factor=1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + AttentionBlock( + in_channels, + num_groups=resnet_groups, + eps=resnet_eps, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +class AttentionBlock(nn.Module): + """Simple attention block for VAE mid block. + Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case. + Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure. + Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter. + """ + def __init__(self, channels, num_groups=32, eps=1e-6): + super().__init__() + self.channels = channels + self.eps = eps + self.heads = 1 + self.rescale_output_factor = 1.0 + + self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True) + self.to_q = nn.Linear(channels, channels, bias=True) + self.to_k = nn.Linear(channels, channels, bias=True) + self.to_v = nn.Linear(channels, channels, bias=True) + self.to_out = nn.ModuleList([ + nn.Linear(channels, channels, bias=True), + nn.Dropout(0.0), + ]) + + def forward(self, hidden_states): + residual = hidden_states + + # Group norm + hidden_states = self.group_norm(hidden_states) + + # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C) + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # QKV projection + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim) + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Scaled dot-product attention + hidden_states = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + # Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # Output projection + dropout + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + # Reshape back to 4D and add residual + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + hidden_states = hidden_states + residual + + # Rescale output factor + hidden_states = hidden_states / self.rescale_output_factor + + return hidden_states + + +class Downsample2D(nn.Module): + """Downsampling layer matching diffusers Downsample2D with use_conv=True. + Key names: conv.weight/bias. + When padding=0, applies explicit F.pad before conv to match dimension. + """ + def __init__(self, in_channels, out_channels, padding=1): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) + self.padding = padding + + def forward(self, hidden_states): + if self.padding == 0: + import torch.nn.functional as F + hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + return self.conv(hidden_states) + + +class Upsample2D(nn.Module): + """Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias.""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states): + hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + return self.conv(hidden_states) + + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.down_blocks = nn.ModuleList([]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + down_block = DownEncoderBlock2D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_downsample=not is_final_block, + downsample_padding=0, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + def forward(self, sample): + sample = self.conv_in(sample) + for down_block in self.down_blocks: + sample = down_block(sample) + sample = self.mid_block(sample) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class Decoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + norm_type="group", + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.up_blocks = nn.ModuleList([]) + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + up_block = UpDecoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=not is_final_block, + temb_channels=temb_channels, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + def forward(self, sample, latent_embeds=None): + sample = self.conv_in(sample) + sample = self.mid_block(sample, latent_embeds) + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + return sample + + +class StableDiffusionVAE(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), + up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + block_out_channels=(128, 256, 512, 512), + layers_per_block=2, + act_fn="silu", + latent_channels=4, + norm_num_groups=32, + sample_size=512, + scaling_factor=0.18215, + shift_factor=None, + latents_mean=None, + latents_std=None, + force_upcast=True, + use_quant_conv=True, + use_post_quant_conv=True, + mid_block_add_attention=True, + ): + super().__init__() + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + self.latents_mean = latents_mean + self.latents_std = latents_std + self.scaling_factor = scaling_factor + self.shift_factor = shift_factor + self.sample_size = sample_size + self.force_upcast = force_upcast + + def _encode(self, x): + h = self.encoder(x) + if self.quant_conv is not None: + h = self.quant_conv(h) + return h + + def encode(self, x): + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + return posterior + + def _decode(self, z): + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + return self.decoder(z) + + def decode(self, z): + return self._decode(z) + + def forward(self, sample, sample_posterior=True, return_dict=True, generator=None): + posterior = self.encode(sample) + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + # Scale latent + z = z * self.scaling_factor + decode = self.decode(z) + if return_dict: + return {"sample": decode, "posterior": posterior, "latent_sample": z} + return decode, posterior diff --git a/diffsynth/models/stable_diffusion_xl_text_encoder.py b/diffsynth/models/stable_diffusion_xl_text_encoder.py new file mode 100644 index 0000000..559eb10 --- /dev/null +++ b/diffsynth/models/stable_diffusion_xl_text_encoder.py @@ -0,0 +1,62 @@ +import torch + + +class SDXLTextEncoder2(torch.nn.Module): + def __init__( + self, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + vocab_size=49408, + layer_norm_eps=1e-05, + hidden_act="gelu", + initializer_factor=1.0, + initializer_range=0.02, + bos_token_id=0, + eos_token_id=2, + pad_token_id=1, + projection_dim=1280, + ): + super().__init__() + from transformers import CLIPTextConfig, CLIPTextModelWithProjection + + config = CLIPTextConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + vocab_size=vocab_size, + layer_norm_eps=layer_norm_eps, + hidden_act=hidden_act, + initializer_factor=initializer_factor, + initializer_range=initializer_range, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + projection_dim=projection_dim, + ) + self.model = CLIPTextModelWithProjection(config) + self.config = config + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_hidden_states=True, + **kwargs, + ): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=output_hidden_states, + return_dict=True, + **kwargs, + ) + if output_hidden_states: + return outputs.text_embeds, outputs.hidden_states + return outputs.text_embeds diff --git a/diffsynth/models/stable_diffusion_xl_unet.py b/diffsynth/models/stable_diffusion_xl_unet.py new file mode 100644 index 0000000..a13a74e --- /dev/null +++ b/diffsynth/models/stable_diffusion_xl_unet.py @@ -0,0 +1,922 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from typing import Optional + + +# ===== Time Embedding ===== + +class Timesteps(nn.Module): + def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.freq_shift = freq_shift + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / half_dim + self.freq_shift + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + if self.flip_sin_to_cos: + emb = torch.cat([cos_emb, sin_emb], dim=-1) + else: + emb = torch.cat([sin_emb, cos_emb], dim=-1) + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None): + super().__init__() + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + self.act = nn.SiLU() if act_fn == "silu" else nn.GELU() + out_dim = out_dim if out_dim is not None else time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, out_dim) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + + +# ===== ResNet Blocks ===== + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps) + self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels)) + + self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = nn.SiLU() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + elif non_linearity == "gelu": + self.nonlinearity = nn.GELU() + elif non_linearity == "relu": + self.nonlinearity = nn.ReLU() + + self.use_conv_shortcut = conv_shortcut + self.conv_shortcut = None + if conv_shortcut: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1) + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + return output_tensor + + +# ===== Transformer Blocks ===== + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, dropout=0.0): + super().__init__() + self.net = nn.ModuleList([ + GEGLU(dim, dim * 4), + nn.Dropout(dropout), + nn.Linear(dim * 4, dim if dim_out is None else dim_out), + ]) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class Attention(nn.Module): + def __init__( + self, + query_dim, + heads=8, + dim_head=64, + dropout=0.0, + bias=False, + upcast_attention=False, + cross_attention_dim=None, + ): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.inner_dim = inner_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias) + self.to_out = nn.ModuleList([ + nn.Linear(inner_dim, query_dim, bias=True), + nn.Dropout(dropout), + ]) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + query = self.to_q(hidden_states) + batch_size, seq_len, _ = query.shape + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + head_dim = self.inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + cross_attention_dim=None, + upcast_attention=False, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn1 = Attention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + ) + self.norm2 = nn.LayerNorm(dim) + self.attn2 = Attention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + cross_attention_dim=cross_attention_dim, + ) + self.norm3 = nn.LayerNorm(dim) + self.ff = FeedForward(dim, dropout=dropout) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + attn_output = self.attn1(self.norm1(hidden_states)) + hidden_states = attn_output + hidden_states + attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + ff_output = self.ff(self.norm3(hidden_states)) + hidden_states = ff_output + hidden_states + return hidden_states + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads=16, + attention_head_dim=64, + in_channels=320, + num_layers=1, + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + upcast_attention=False, + use_linear_projection=False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim, bias=True) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + dim=inner_dim, + n_heads=num_attention_heads, + d_head=attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + for _ in range(num_layers) + ]) + + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels, bias=True) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch, channel, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + + if self.use_linear_projection: + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel) + hidden_states = self.proj_in(hidden_states) + else: + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + + if self.use_linear_projection: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous() + else: + hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +# ===== Down/Up Blocks ===== + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + downsample=True, + use_linear_projection=False, + ): + super().__init__() + self.has_cross_attention = True + + resnets = [] + attentions = [] + + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D(out_channels, out_channels, padding=1) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + output_states.append(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states.append(hidden_states) + + return hidden_states, tuple(output_states) + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + downsample=True, + ): + super().__init__() + self.has_cross_attention = False + + resnets = [] + for i in range(num_layers): + in_channels_i = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels_i, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if downsample: + self.downsamplers = nn.ModuleList([ + Downsample2D(out_channels, out_channels, padding=1) + ]) + else: + self.downsamplers = None + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = [] + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states.append(hidden_states) + + return hidden_states, tuple(output_states) + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + prev_output_channel, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + upsample=True, + use_linear_projection=False, + ): + super().__init__() + self.has_cross_attention = True + + resnets = [] + attentions = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=out_channels // attention_head_dim, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D(out_channels, out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None): + for resnet, attn in zip(self.resnets, self.attentions): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size=upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + prev_output_channel, + temb_channels=1280, + dropout=0.0, + num_layers=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + upsample=True, + ): + super().__init__() + self.has_cross_attention = False + + resnets = [] + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if upsample: + self.upsamplers = nn.ModuleList([ + Upsample2D(out_channels, out_channels) + ]) + else: + self.upsamplers = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None): + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size=upsample_size) + + return hidden_states + + +# ===== UNet Mid Block ===== + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels, + temb_channels=1280, + dropout=0.0, + num_layers=1, + transformer_layers_per_block=1, + resnet_eps=1e-6, + resnet_time_scale_shift="default", + resnet_act_fn="swish", + resnet_groups=32, + resnet_pre_norm=True, + cross_attention_dim=768, + attention_head_dim=1, + use_linear_projection=False, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + num_attention_heads=attention_head_dim, + attention_head_dim=in_channels // attention_head_dim, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + dropout=dropout, + norm_num_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + use_linear_projection=use_linear_projection, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=1.0, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + return hidden_states + + +# ===== Downsample / Upsample ===== + +class Downsample2D(nn.Module): + def __init__(self, in_channels, out_channels, padding=1): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding) + self.padding = padding + + def forward(self, hidden_states): + if self.padding == 0: + hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + return self.conv(hidden_states) + + +class Upsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states, upsample_size=None): + if upsample_size is not None: + hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + return self.conv(hidden_states) + + +# ===== SDXL UNet2DConditionModel ===== + +class SDXLUNet2DConditionModel(nn.Module): + def __init__( + self, + sample_size=128, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + block_out_channels=(320, 640, 1280), + layers_per_block=2, + cross_attention_dim=2048, + attention_head_dim=5, + transformer_layers_per_block=1, + norm_num_groups=32, + norm_eps=1e-5, + dropout=0.0, + act_fn="silu", + time_embedding_type="positional", + flip_sin_to_cos=True, + freq_shift=0, + time_embedding_dim=None, + resnet_time_scale_shift="default", + upcast_attention=False, + use_linear_projection=False, + addition_embed_type=None, + addition_time_embed_dim=None, + projection_class_embeddings_input_dim=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.sample_size = sample_size + self.addition_embed_type = addition_embed_type + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + timestep_embedding_dim = time_embedding_dim or block_out_channels[0] + self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift) + time_embed_dim = block_out_channels[0] * 4 + self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim) + + if addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if "CrossAttn" in down_block_type: + down_block = CrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim[i], + downsample=not is_final_block, + use_linear_projection=use_linear_projection, + ) + else: + down_block = DownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample=not is_final_block, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=1, + transformer_layers_per_block=transformer_layers_per_block[-1], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim[-1], + use_linear_projection=use_linear_projection, + ) + + self.up_blocks = nn.ModuleList() + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + output_channel = reversed_block_out_channels[0] + + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + if "CrossAttn" in up_block_type: + up_block = CrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=reversed_attention_head_dim[i], + upsample=not is_final_block, + use_linear_projection=use_linear_projection, + ) + else: + up_block = UpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + dropout=dropout, + num_layers=layers_per_block + 1, + resnet_eps=norm_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True): + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + t_emb = self.time_proj(timesteps) + t_emb = t_emb.to(dtype=sample.dtype) + emb = self.time_embedding(t_emb) + + if self.addition_embed_type == "text_time": + text_embeds = added_cond_kwargs.get("text_embeds") + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + emb = emb + aug_emb + + sample = self.conv_in(sample) + + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + sample, res_samples = down_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + down_block_res_samples += res_samples + + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-len(up_block.resnets):] + down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)] + + upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None + sample = up_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return sample diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py new file mode 100644 index 0000000..bb8e298 --- /dev/null +++ b/diffsynth/pipelines/stable_diffusion.py @@ -0,0 +1,222 @@ +import torch +from PIL import Image +from tqdm import tqdm +from typing import Union + +from ..core.device.npu_compatible_device import get_device_type +from ..diffusion.ddim_scheduler import DDIMScheduler +from ..core import ModelConfig +from ..diffusion.base_pipeline import BasePipeline, PipelineUnit + +from transformers import AutoTokenizer, CLIPTextModel +from ..models.stable_diffusion_text_encoder import SDTextEncoder +from ..models.stable_diffusion_unet import UNet2DConditionModel +from ..models.stable_diffusion_vae import StableDiffusionVAE + + +class StableDiffusionPipeline(BasePipeline): + + def __init__(self, device=get_device_type(), torch_dtype=torch.float16): + super().__init__( + device=device, torch_dtype=torch_dtype, + height_division_factor=8, width_division_factor=8, + ) + self.scheduler = DDIMScheduler() + self.text_encoder: SDTextEncoder = None + self.unet: UNet2DConditionModel = None + self.vae: StableDiffusionVAE = None + self.tokenizer: AutoTokenizer = None + + self.in_iteration_models = ("unet",) + self.units = [ + SDUnit_ShapeChecker(), + SDUnit_PromptEmbedder(), + SDUnit_NoiseInitializer(), + SDUnit_InputImageEmbedder(), + ] + self.model_fn = model_fn_stable_diffusion + self.compilable_models = ["unet"] + + @staticmethod + def from_pretrained( + torch_dtype: torch.dtype = torch.float16, + device: Union[str, torch.device] = get_device_type(), + model_configs: list[ModelConfig] = [], + tokenizer_config: ModelConfig = None, + vram_limit: float = None, + ): + pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype) + # Override vram_config to use the specified torch_dtype for all models + for mc in model_configs: + mc._vram_config_override = { + 'onload_dtype': torch_dtype, + 'computation_dtype': torch_dtype, + } + model_pool = pipe.download_and_load_models(model_configs, vram_limit) + pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder") + pipe.unet = model_pool.fetch_model("stable_diffusion_unet") + pipe.vae = model_pool.fetch_model("stable_diffusion_vae") + if tokenizer_config is not None: + tokenizer_config.download_if_necessary() + pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path) + pipe.vram_management_enabled = pipe.check_vram_management_state() + return pipe + + @torch.no_grad() + def __call__( + self, + prompt: str, + negative_prompt: str = "", + cfg_scale: float = 7.5, + height: int = 512, + width: int = 512, + seed: int = None, + rand_device: str = "cpu", + num_inference_steps: int = 50, + eta: float = 0.0, + guidance_rescale: float = 0.0, + progress_bar_cmd=tqdm, + ): + # 1. Scheduler + self.scheduler.set_timesteps( + num_inference_steps, eta=eta, + ) + + # 2. Three-dict input preparation + inputs_posi = {"prompt": prompt} + inputs_nega = {"negative_prompt": negative_prompt} + inputs_shared = { + "cfg_scale": cfg_scale, + "height": height, "width": width, + "seed": seed, "rand_device": rand_device, + "guidance_rescale": guidance_rescale, + } + + # 3. Unit chain execution + for unit in self.units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner( + unit, self, inputs_shared, inputs_posi, inputs_nega + ) + + # 4. Denoise loop + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, + inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["latents"] = self.step( + self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared + ) + + # 5. VAE decode + self.load_models_to_device(['vae']) + latents = inputs_shared["latents"] / self.vae.scaling_factor + image = self.vae.decode(latents) + image = self.vae_output_to_image(image) + self.load_models_to_device([]) + + return image + + +class SDUnit_ShapeChecker(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width"), + output_params=("height", "width"), + ) + + def process(self, pipe: StableDiffusionPipeline, height, width): + height, width = pipe.check_resize_height_width(height, width) + return {"height": height, "width": width} + + +class SDUnit_PromptEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + seperate_cfg=True, + input_params_posi={"prompt": "prompt"}, + input_params_nega={"prompt": "negative_prompt"}, + output_params=("prompt_embeds",), + onload_model_names=("text_encoder",) + ) + + def encode_prompt( + self, + pipe: StableDiffusionPipeline, + prompt: str, + device: torch.device, + ) -> torch.Tensor: + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_embeds = pipe.text_encoder(text_input_ids) + # TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state. + # last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt. + if isinstance(prompt_embeds, tuple): + prompt_embeds = prompt_embeds[0] + return prompt_embeds + + def process(self, pipe: StableDiffusionPipeline, prompt): + pipe.load_models_to_device(self.onload_model_names) + prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device) + return {"prompt_embeds": prompt_embeds} + + +class SDUnit_NoiseInitializer(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("height", "width", "seed", "rand_device"), + output_params=("noise",), + ) + + def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device): + noise = pipe.generate_noise( + (1, pipe.unet.in_channels, height // 8, width // 8), + seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype + ) + return {"noise": noise} + + +class SDUnit_InputImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("noise",), + output_params=("latents",), + ) + + def process(self, pipe: StableDiffusionPipeline, noise): + # For Text-to-Image, latents = noise (scaled by scheduler) + latents = noise * pipe.scheduler.init_noise_sigma + return {"latents": latents} + + +def model_fn_stable_diffusion( + unet: UNet2DConditionModel, + latents=None, + timestep=None, + prompt_embeds=None, + cross_attention_kwargs=None, + timestep_cond=None, + added_cond_kwargs=None, + **kwargs, +): + # SD timestep is already in 0-999 range, no scaling needed + noise_pred = unet( + latents, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + timestep_cond=timestep_cond, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + return noise_pred diff --git a/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py b/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py new file mode 100644 index 0000000..3086b65 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/stable_diffusion_text_encoder.py @@ -0,0 +1,7 @@ +def SDTextEncoderStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if key.startswith("text_model.") and "position_ids" not in key: + new_key = "model." + key + new_state_dict[new_key] = state_dict[key] + return new_state_dict diff --git a/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py b/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py new file mode 100644 index 0000000..a41d3ce --- /dev/null +++ b/diffsynth/utils/state_dict_converters/stable_diffusion_vae.py @@ -0,0 +1,18 @@ +def SDVAEStateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if ".query." in key: + new_key = key.replace(".query.", ".to_q.") + new_state_dict[new_key] = state_dict[key] + elif ".key." in key: + new_key = key.replace(".key.", ".to_k.") + new_state_dict[new_key] = state_dict[key] + elif ".value." in key: + new_key = key.replace(".value.", ".to_v.") + new_state_dict[new_key] = state_dict[key] + elif ".proj_attn." in key: + new_key = key.replace(".proj_attn.", ".to_out.0.") + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[key] = state_dict[key] + return new_state_dict diff --git a/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py new file mode 100644 index 0000000..789decb --- /dev/null +++ b/diffsynth/utils/state_dict_converters/stable_diffusion_xl_text_encoder.py @@ -0,0 +1,7 @@ +def SDXLTextEncoder2StateDictConverter(state_dict): + new_state_dict = {} + for key in state_dict: + if key.startswith("text_model.") and "position_ids" not in key: + new_key = "model." + key + new_state_dict[new_key] = state_dict[key] + return new_state_dict diff --git a/examples/stable_diffusion/model_inference/StableDiffusion-T2I.py b/examples/stable_diffusion/model_inference/StableDiffusion-T2I.py new file mode 100644 index 0000000..ea09a44 --- /dev/null +++ b/examples/stable_diffusion/model_inference/StableDiffusion-T2I.py @@ -0,0 +1,25 @@ +import torch +from diffsynth.core import ModelConfig +from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + torch_dtype=torch.float32, + model_configs=[ + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"), + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"), +) + +image = pipe( + prompt="a photo of an astronaut riding a horse on mars", + negative_prompt="", + cfg_scale=7.5, + height=512, + width=512, + seed=42, + num_inference_steps=50, +) +image.save("output_stable_diffusion_t2i.png") +print("Image saved to output_stable_diffusion_t2i.png") diff --git a/examples/stable_diffusion/model_inference_low_vram/StableDiffusion-T2I.py b/examples/stable_diffusion/model_inference_low_vram/StableDiffusion-T2I.py new file mode 100644 index 0000000..1947706 --- /dev/null +++ b/examples/stable_diffusion/model_inference_low_vram/StableDiffusion-T2I.py @@ -0,0 +1,37 @@ +import torch +from diffsynth.core import ModelConfig +from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline + +vram_config = { + "offload_dtype": torch.float32, + "offload_device": "cpu", + "onload_dtype": torch.float32, + "onload_device": "cpu", + "preparing_dtype": torch.float32, + "preparing_device": "cuda", + "computation_dtype": torch.float32, + "computation_device": "cuda", +} + +pipe = StableDiffusionPipeline.from_pretrained( + torch_dtype=torch.float32, + model_configs=[ + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config), + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config), + ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) + +image = pipe( + prompt="a photo of an astronaut riding a horse on mars", + negative_prompt="", + cfg_scale=7.5, + height=512, + width=512, + seed=42, + num_inference_steps=50, +) +image.save("output_stable_diffusion_t2i_low_vram.png") +print("Image saved to output_stable_diffusion_t2i_low_vram.png")