This commit is contained in:
mi804
2026-04-23 17:35:24 +08:00
parent 5c89a15b9a
commit 82e482286c
14 changed files with 3274 additions and 1 deletions

View File

@@ -0,0 +1,78 @@
import torch
class SDTextEncoder(torch.nn.Module):
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="quick_gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=768,
):
super().__init__()
from transformers import CLIPConfig, CLIPTextModel
config = CLIPConfig(
text_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"vocab_size": vocab_size,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"bos_token_id": bos_token_id,
"eos_token_id": eos_token_id,
"pad_token_id": pad_token_id,
"projection_dim": projection_dim,
"dropout": 0.0,
},
vision_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"projection_dim": projection_dim,
},
projection_dim=projection_dim,
)
self.model = CLIPTextModel(config.text_config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.last_hidden_state, outputs.hidden_states
return outputs.last_hidden_state

View File

@@ -0,0 +1,912 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
"""Attention block matching diffusers checkpoint key format.
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
"""
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Query
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
# Key/Value
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
# Reshape for multi-head attention
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Self-attention
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
# Cross-attention
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# Feed-forward
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
"""2D Transformer block wrapper matching diffusers checkpoint structure.
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
"""
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=num_attention_heads * attention_head_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
# Normalize and project to sequence
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
# Transformer blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
# Project back to 2D
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# Pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# There is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== UNet2DConditionModel =====
class UNet2DConditionModel(nn.Module):
"""Stable Diffusion UNet with cross-attention conditioning.
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
"""
def __init__(
self,
sample_size=64,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
cross_attention_dim=768,
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
# Time embedding
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
# Input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
# Down blocks
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
downsample=not is_final_block,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
# Mid block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
)
# Up blocks
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
upsample=not is_final_block,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
# Output
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
# 1. Time embedding
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
# 2. Pre-process
sample = self.conv_in(sample)
# 3. Down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
# 4. Mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. Up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
# 6. Post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample

View File

@@ -0,0 +1,642 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Optional
class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# randn_like doesn't accept generator on all torch versions
sample = torch.randn(self.mean.shape, generator=generator,
device=self.parameters.device, dtype=self.parameters.dtype)
return self.mean + self.std * sample
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
if self.deterministic:
return torch.tensor([0.0])
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def mode(self) -> torch.Tensor:
return self.mean
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
else:
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=downsample_padding)
])
else:
self.downsamplers = None
def forward(self, hidden_states, *args, **kwargs):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_upsample=True,
temb_channels=None,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
def __init__(
self,
in_channels,
temb_channels=None,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
add_attention=True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
in_channels,
num_groups=resnet_groups,
eps=resnet_eps,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class AttentionBlock(nn.Module):
"""Simple attention block for VAE mid block.
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
"""
def __init__(self, channels, num_groups=32, eps=1e-6):
super().__init__()
self.channels = channels
self.eps = eps
self.heads = 1
self.rescale_output_factor = 1.0
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
self.to_q = nn.Linear(channels, channels, bias=True)
self.to_k = nn.Linear(channels, channels, bias=True)
self.to_v = nn.Linear(channels, channels, bias=True)
self.to_out = nn.ModuleList([
nn.Linear(channels, channels, bias=True),
nn.Dropout(0.0),
])
def forward(self, hidden_states):
residual = hidden_states
# Group norm
hidden_states = self.group_norm(hidden_states)
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# QKV projection
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection + dropout
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
# Reshape back to 4D and add residual
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
hidden_states = hidden_states + residual
# Rescale output factor
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
class Downsample2D(nn.Module):
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
Key names: conv.weight/bias.
When padding=0, applies explicit F.pad before conv to match dimension.
"""
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
import torch.nn.functional as F
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
class Encoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
downsample_padding=0,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
def forward(self, sample):
sample = self.conv_in(sample)
for down_block in self.down_blocks:
sample = down_block(sample)
sample = self.mid_block(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group",
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_upsample=not is_final_block,
temb_channels=temb_channels,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, sample, latent_embeds=None):
sample = self.conv_in(sample)
sample = self.mid_block(sample, latent_embeds)
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class StableDiffusionVAE(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
block_out_channels=(128, 256, 512, 512),
layers_per_block=2,
act_fn="silu",
latent_channels=4,
norm_num_groups=32,
sample_size=512,
scaling_factor=0.18215,
shift_factor=None,
latents_mean=None,
latents_std=None,
force_upcast=True,
use_quant_conv=True,
use_post_quant_conv=True,
mid_block_add_attention=True,
):
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.latents_mean = latents_mean
self.latents_std = latents_std
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.sample_size = sample_size
self.force_upcast = force_upcast
def _encode(self, x):
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
return h
def encode(self, x):
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
return posterior
def _decode(self, z):
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
return self.decoder(z)
def decode(self, z):
return self._decode(z)
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
posterior = self.encode(sample)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
# Scale latent
z = z * self.scaling_factor
decode = self.decode(z)
if return_dict:
return {"sample": decode, "posterior": posterior, "latent_sample": z}
return decode, posterior

View File

@@ -0,0 +1,62 @@
import torch
class SDXLTextEncoder2(torch.nn.Module):
def __init__(
self,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=1280,
):
super().__init__()
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
config = CLIPTextConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
vocab_size=vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_act=hidden_act,
initializer_factor=initializer_factor,
initializer_range=initializer_range,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
projection_dim=projection_dim,
)
self.model = CLIPTextModelWithProjection(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.text_embeds, outputs.hidden_states
return outputs.text_embeds

View File

@@ -0,0 +1,922 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
use_linear_projection=False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.use_linear_projection = use_linear_projection
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=inner_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
else:
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
use_linear_projection=False,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== SDXL UNet2DConditionModel =====
class SDXLUNet2DConditionModel(nn.Module):
def __init__(
self,
sample_size=128,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
block_out_channels=(320, 640, 1280),
layers_per_block=2,
cross_attention_dim=2048,
attention_head_dim=5,
transformer_layers_per_block=1,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
use_linear_projection=False,
addition_embed_type=None,
addition_time_embed_dim=None,
projection_class_embeddings_input_dim=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
self.addition_embed_type = addition_embed_type
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[i],
downsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[-1],
use_linear_projection=use_linear_projection,
)
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=reversed_attention_head_dim[i],
upsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
if self.addition_embed_type == "text_time":
text_embeds = added_cond_kwargs.get("text_embeds")
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb
sample = self.conv_in(sample)
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample