# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DiagonalGaussianDistribution:
    """Diagonal Gaussian parameterised by a tensor holding mean and log-variance.

    ``parameters`` is split in two equal halves along ``dim=1``: the first half
    is the mean, the second half the log-variance (clamped to ``[-30, 20]``).

    Args:
        parameters: Tensor whose ``dim=1`` holds ``[mean, logvar]`` stacked.
        deterministic: When True, the variance and standard deviation are
            forced to zero so ``sample()`` returns the mean.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw ``mean + std * eps`` with ``eps`` sampled from a standard normal."""
        # randn_like doesn't accept generator on all torch versions
        noise = torch.randn(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """Return the KL divergence, summed over dims 1, 2 and 3.

        Args:
            other: Distribution to compare against. When None, the reference
                is a standard normal.

        Returns:
            One value per batch element, or a single-element zero tensor when
            the distribution is deterministic.
        """
        if self.deterministic:
            # Keep the zero on the same device/dtype as the parameters so the
            # result can be combined with other tensors derived from them.
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def mode(self) -> torch.Tensor:
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean


class ResnetBlock2D(nn.Module):
    """Residual block: (GroupNorm -> activation -> 3x3 conv) twice plus a skip path.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output; falls back to ``in_channels``
            when None (or otherwise falsy).
        conv_shortcut: Force a 1x1 conv on the skip path even when the channel
            counts already match.
        dropout: Dropout probability applied before the second conv.
        temb_channels: Width of the optional embedding; None disables the
            embedding projection.
        groups: GroupNorm groups for the first norm.
        groups_out: GroupNorm groups for the second norm (defaults to ``groups``).
        pre_norm: Stored on the module; not otherwise used by this block.
        eps: GroupNorm epsilon.
        non_linearity: One of ``"swish"``, ``"silu"``, ``"gelu"``, ``"relu"``.
        time_embedding_norm: ``"default"`` adds the projected embedding before
            the second norm; ``"scale_shift"`` applies it as scale/shift after.
            Any other value creates no projection.
        output_scale_factor: The summed output is divided by this value.
        use_in_shortcut: Accepted for signature compatibility; unused here.

    Raises:
        ValueError: If ``non_linearity`` is not one of the supported names.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        out_channels = out_channels or in_channels
        if groups_out is None:
            groups_out = groups

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Always define the attribute so forward() can test for it instead of
        # failing with an AttributeError when no projection was configured.
        self.time_emb_proj = None
        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        activations = {"swish": nn.SiLU, "silu": nn.SiLU, "gelu": nn.GELU, "relu": nn.ReLU}
        if non_linearity not in activations:
            raise ValueError(f"Unsupported non_linearity: {non_linearity}")
        self.nonlinearity = activations[non_linearity]()

        self.use_conv_shortcut = conv_shortcut
        self.conv_shortcut = None
        # A 1x1 projection is needed whenever it is requested explicitly or the
        # residual would otherwise have a different channel count.
        if conv_shortcut or in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb=None):
        """Apply the block.

        Args:
            input_tensor: Feature map fed to both the main and the skip path.
            temb: Optional embedding; requires a configured projection.

        Returns:
            ``(skip + main) / output_scale_factor``.

        Raises:
            ValueError: If ``temb`` is given but the block has no projection.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            if self.time_emb_proj is None:
                raise ValueError(
                    "temb was provided but this ResnetBlock2D has no time_emb_proj "
                    "(temb_channels is None or time_embedding_norm is unsupported)."
                )
            temb = self.nonlinearity(temb)
            # Add two trailing singleton dims so the embedding broadcasts over H and W.
            temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor


class DownEncoderBlock2D(nn.Module):
    """Stack of ``num_layers`` resnets (no embedding) with an optional downsampler.

    Args:
        in_channels: Channels entering the first resnet.
        out_channels: Channels produced by every resnet and the downsampler.
        add_downsample: Append a ``Downsample2D`` after the resnets.
        downsample_padding: Forwarded to ``Downsample2D`` as ``padding``.

    The remaining ``resnet_*``/``dropout``/``output_scale_factor`` arguments are
    forwarded to each ``ResnetBlock2D``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            # Only the first resnet sees the block's input width.
            in_channels_i = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels_i,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample2D(out_channels, out_channels, padding=downsample_padding)]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states, *args, **kwargs):
        """Run the resnets then the downsamplers; extra arguments are ignored."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states


class UpDecoderBlock2D(nn.Module):
    """Stack of ``num_layers`` resnets with an optional 2x upsampler.

    Args:
        in_channels: Channels entering the first resnet.
        out_channels: Channels produced by every resnet and the upsampler.
        add_upsample: Append an ``Upsample2D`` after the resnets.
        temb_channels: Embedding width forwarded to each resnet (None disables).

    The remaining ``resnet_*``/``dropout``/``output_scale_factor`` arguments are
    forwarded to each ``ResnetBlock2D``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        output_scale_factor=1.0,
        add_upsample=True,
        temb_channels=None,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            # Only the first resnet sees the block's input width.
            in_channels_i = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels_i,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, temb=None):
        """Run the resnets (passing ``temb`` through) then the upsamplers."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UNetMidBlock2D(nn.Module):
    """Middle block: one resnet followed by ``num_layers`` (attention, resnet) pairs.

    Args:
        in_channels: Channel count, kept constant through the block.
        temb_channels: Embedding width forwarded to each resnet (None disables).
        resnet_groups: GroupNorm groups; when None, ``min(in_channels // 4, 32)``.
        add_attention: When False, the attention slots hold None and are skipped.
        attention_head_dim: Accepted for signature compatibility; the
            ``AttentionBlock`` built here does not take a head dimension.

    The remaining ``resnet_*``/``dropout``/``output_scale_factor`` arguments are
    forwarded to each ``ResnetBlock2D``.
    """

    def __init__(
        self,
        in_channels,
        temb_channels=None,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        add_attention=True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        def make_resnet():
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnets = [make_resnet()]
        attentions = []

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    AttentionBlock(
                        in_channels,
                        num_groups=resnet_groups,
                        eps=resnet_eps,
                    )
                )
            else:
                # Placeholder keeps attentions aligned with resnets[1:] in forward().
                attentions.append(None)
            resnets.append(make_resnet())

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        """Apply the first resnet, then each (attention, resnet) pair in order."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states


class AttentionBlock(nn.Module):
    """Simple attention block for VAE mid block.

    Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
    Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
    Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
    """

    def __init__(self, channels, num_groups=32, eps=1e-6):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.heads = 1
        self.rescale_output_factor = 1.0

        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
        self.to_q = nn.Linear(channels, channels, bias=True)
        self.to_k = nn.Linear(channels, channels, bias=True)
        self.to_v = nn.Linear(channels, channels, bias=True)
        self.to_out = nn.ModuleList(
            [
                nn.Linear(channels, channels, bias=True),
                nn.Dropout(0.0),
            ]
        )

    def forward(self, hidden_states):
        """Self-attend over the spatial positions of a ``(B, C, H, W)`` tensor.

        Returns a tensor of the same shape: attention output plus the input
        residual, divided by ``rescale_output_factor``.
        """
        residual = hidden_states

        # Group norm
        hidden_states = self.group_norm(hidden_states)

        # Flatten spatial dims: (B, C, H, W) -> (B, H*W, C).
        # reshape (rather than view) also accepts non-contiguous inputs.
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.reshape(batch_size, channel, height * width).transpose(1, 2)

        # QKV projection
        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        # Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        # Scaled dot-product attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Output projection + dropout
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)

        # Reshape back to 4D and add residual
        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        hidden_states = hidden_states + residual

        # Rescale output factor
        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states


class Downsample2D(nn.Module):
    """Downsampling layer matching diffusers Downsample2D with use_conv=True.

    Key names: conv.weight/bias.
    The stride-2 3x3 conv uses ``padding`` as its own symmetric padding. When
    ``padding == 0``, the input is instead zero-padded by one pixel on the
    right and bottom before the conv.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        padding: Padding of the conv (see above).
    """

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        # Honour `padding` in the conv itself; it was previously ignored, which
        # shrank the output by one pixel for any non-zero padding.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
        self.padding = padding

    def forward(self, hidden_states):
        """Return the downsampled feature map."""
        if self.padding == 0:
            # (left, right, top, bottom): pad only right and bottom.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(hidden_states)


class Upsample2D(nn.Module):
    """Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states):
        """Nearest-neighbour upsample by 2x, then apply the 3x3 conv."""
        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(hidden_states)


class Encoder(nn.Module):
    """VAE encoder: conv-in, down blocks, mid block, then norm/activation/conv-out.

    Args:
        in_channels: Channels of the input image.
        out_channels: Latent channels; the output has twice as many when
            ``double_z`` is True.
        down_block_types: One entry per down block. Only the length is used;
            every block is built as a ``DownEncoderBlock2D``.
        block_out_channels: Output width of each down block.
        layers_per_block: Resnets per down block.
        norm_num_groups: GroupNorm groups used throughout.
        act_fn: Activation name forwarded to the resnets.
        double_z: Emit ``2 * out_channels`` channels.
        mid_block_add_attention: Whether the mid block contains attention.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, _block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The last block keeps the spatial resolution.
            is_final_block = i == len(block_out_channels) - 1

            self.down_blocks.append(
                DownEncoderBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    num_layers=self.layers_per_block,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_downsample=not is_final_block,
                    downsample_padding=0,
                )
            )

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

    def forward(self, sample):
        """Encode ``sample`` into the pre-quantisation feature map."""
        sample = self.conv_in(sample)

        for down_block in self.down_blocks:
            sample = down_block(sample)

        sample = self.mid_block(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample


class Decoder(nn.Module):
    """VAE decoder: conv-in, mid block, up blocks, then norm/activation/conv-out.

    Args:
        in_channels: Latent channels.
        out_channels: Channels of the reconstructed image.
        up_block_types: One entry per up block. Only the length is used; every
            block is built as an ``UpDecoderBlock2D``.
        block_out_channels: Widths in encoder order; traversed in reverse here.
        layers_per_block: Each up block has ``layers_per_block + 1`` resnets.
        norm_num_groups: GroupNorm groups used throughout.
        act_fn: Activation name forwarded to the resnets.
        norm_type: ``"group"`` maps to the resnets' ``"default"`` embedding
            mode; any other value is forwarded as the embedding mode, and
            ``"spatial"`` additionally sets ``temb_channels = in_channels``.
        mid_block_add_attention: Whether the mid block contains attention.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, _block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # The last block keeps the spatial resolution.
            is_final_block = i == len(block_out_channels) - 1

            self.up_blocks.append(
                UpDecoderBlock2D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    num_layers=self.layers_per_block + 1,
                    resnet_eps=1e-6,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    add_upsample=not is_final_block,
                    temb_channels=temb_channels,
                )
            )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

    def forward(self, sample, latent_embeds=None):
        """Decode latent ``sample``; ``latent_embeds`` is passed to mid and up blocks as ``temb``."""
        sample = self.conv_in(sample)

        sample = self.mid_block(sample, latent_embeds)

        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample


class StableDiffusionVAE(nn.Module):
    """Encoder/decoder pair with optional 1x1 quant and post-quant convolutions.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the reconstructed image.
        down_block_types: Forwarded to ``Encoder``.
        up_block_types: Forwarded to ``Decoder``.
        block_out_channels: Widths shared by encoder and decoder.
        layers_per_block: Forwarded to ``Encoder`` and ``Decoder``.
        act_fn: Activation name forwarded to ``Encoder`` and ``Decoder``.
        latent_channels: Channels of the latent.
        norm_num_groups: GroupNorm groups used throughout.
        sample_size: Stored on the module; not used by this class.
        scaling_factor: Multiplier applied to the latent in ``forward``.
        shift_factor: Stored on the module; not used by this class.
        latents_mean: Stored on the module; not used by this class.
        latents_std: Stored on the module; not used by this class.
        force_upcast: Stored on the module; not used by this class.
        use_quant_conv: Create the 1x1 conv applied to the encoder output.
        use_post_quant_conv: Create the 1x1 conv applied before decoding.
        mid_block_add_attention: Forwarded to ``Encoder`` and ``Decoder``.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=(
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
            "DownEncoderBlock2D",
        ),
        up_block_types=(
            "UpDecoderBlock2D",
            "UpDecoderBlock2D",
            "UpDecoderBlock2D",
            "UpDecoderBlock2D",
        ),
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        act_fn="silu",
        latent_channels=4,
        norm_num_groups=32,
        sample_size=512,
        scaling_factor=0.18215,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        force_upcast=True,
        use_quant_conv=True,
        use_post_quant_conv=True,
        mid_block_add_attention=True,
    ):
        super().__init__()

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=True,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

        self.latents_mean = latents_mean
        self.latents_std = latents_std
        self.scaling_factor = scaling_factor
        self.shift_factor = shift_factor
        self.sample_size = sample_size
        self.force_upcast = force_upcast

    def _encode(self, x):
        """Run the encoder and, if present, the quant conv."""
        h = self.encoder(x)
        if self.quant_conv is not None:
            h = self.quant_conv(h)
        return h

    def encode(self, x):
        """Return the posterior ``DiagonalGaussianDistribution`` for ``x``."""
        return DiagonalGaussianDistribution(self._encode(x))

    def _decode(self, z):
        """Run the post-quant conv, if present, and the decoder."""
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        return self.decoder(z)

    def decode(self, z):
        """Decode latent ``z``; no scaling is applied or undone here."""
        return self._decode(z)

    def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
        """Encode ``sample``, pick a latent, scale it, and decode it.

        Args:
            sample: Input image batch.
            sample_posterior: Sample from the posterior when True, otherwise
                use its mode.
            return_dict: Return a dict with keys ``"sample"``, ``"posterior"``
                and ``"latent_sample"`` when True, else ``(decoded, posterior)``.
            generator: Optional RNG used when sampling the posterior.
        """
        posterior = self.encode(sample)
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()

        # Scale latent
        # NOTE(review): the decoder receives the *scaled* latent and nothing
        # divides by scaling_factor again -- confirm this is intended.
        z = z * self.scaling_factor

        decoded = self.decode(z)

        if return_dict:
            return {"sample": decoded, "posterior": posterior, "latent_sample": z}
        return decoded, posterior