mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
923 lines
34 KiB
Python
923 lines
34 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
from typing import Optional
|
|
|
|
|
|
# ===== Time Embedding =====
|
|
|
|
class Timesteps(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    Args:
        num_channels: Width of the returned embedding.
        flip_sin_to_cos: When true the cosine half comes first, otherwise
            the sine half comes first.
        freq_shift: Subtracted from ``num_channels // 2`` in the divisor of
            the frequency exponent (``0`` leaves the divisor untouched).
    """

    def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.freq_shift = freq_shift

    def forward(self, timesteps):
        """Return a ``(len(timesteps), num_channels)`` float32 embedding.

        ``timesteps`` is indexed as ``timesteps[:, None]``, so it must be a
        1-D tensor.
        """
        half_dim = self.num_channels // 2
        exponent = -math.log(10000) * torch.arange(
            half_dim, dtype=torch.float32, device=timesteps.device
        )
        # The shift belongs in the divisor. Adding it to the exponent (as the
        # code previously did) rescales every frequency by exp(freq_shift)
        # instead of changing the spacing between frequencies.
        exponent = exponent / (half_dim - self.freq_shift)
        frequencies = torch.exp(exponent)

        angles = timesteps[:, None].float() * frequencies[None, :]
        sin_emb = torch.sin(angles)
        cos_emb = torch.cos(angles)
        if self.flip_sin_to_cos:
            emb = torch.cat([cos_emb, sin_emb], dim=-1)
        else:
            emb = torch.cat([sin_emb, cos_emb], dim=-1)

        # An odd width cannot be split into two equal halves; zero-pad the
        # last channel so the output really has ``num_channels`` columns.
        if self.num_channels % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP: linear -> activation -> linear.

    Args:
        in_channels: Input feature width.
        time_embed_dim: Hidden width (and output width when ``out_dim`` is
            ``None``).
        act_fn: ``"silu"`` selects SiLU; any other value selects GELU.
        out_dim: Optional output width.
    """

    def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = time_embed_dim
        activation_cls = nn.SiLU if act_fn == "silu" else nn.GELU

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = activation_cls()
        self.linear_2 = nn.Linear(time_embed_dim, out_dim)

    def forward(self, sample):
        """Project ``sample`` through both linear layers."""
        hidden = self.linear_1(sample)
        activated = self.act(hidden)
        return self.linear_2(activated)
|
|
|
|
|
|
# ===== ResNet Blocks =====
|
|
|
|
class ResnetBlock2D(nn.Module):
    """Residual block: two (GroupNorm -> activation -> 3x3 conv) stages.

    An optional time embedding is projected and injected between the stages.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output; falls back to ``in_channels``
            when ``None`` (or any other falsy value).
        conv_shortcut: Force a 1x1 convolution on the skip path even when
            the channel counts already match.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Width of the time embedding; ``None`` creates no
            projection layer.
        groups: Group count for the first GroupNorm.
        groups_out: Group count for the second GroupNorm (defaults to
            ``groups``).
        pre_norm: Stored on the instance; not read by this class.
        eps: Epsilon for both GroupNorm layers.
        non_linearity: One of ``"swish"``, ``"silu"``, ``"gelu"``, ``"relu"``.
        time_embedding_norm: ``"default"`` adds the projected embedding
            before the second norm; ``"scale_shift"`` applies it as
            ``x * (1 + scale) + shift`` after the second norm.
        output_scale_factor: The residual sum is divided by this value.
        use_in_shortcut: Accepted for signature compatibility; unused.

    Raises:
        ValueError: If ``non_linearity`` is not a supported name.
    """

    # Supported activation names. "swish" and "silu" are the same function.
    _ACTIVATIONS = {
        "swish": nn.SiLU,
        "silu": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
    }

    def __init__(
        self,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        # Fail at construction: previously an unknown name left
        # ``self.nonlinearity`` undefined and only crashed inside forward().
        if non_linearity not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported non_linearity {non_linearity!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}"
            )

        self.pre_norm = pre_norm
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        out_channels = out_channels or in_channels
        if groups_out is None:
            groups_out = groups

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Twice the width: one half is the scale, the other the shift.
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)

        self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = self._ACTIVATIONS[non_linearity]()

        self.use_conv_shortcut = conv_shortcut
        # The skip path needs a 1x1 conv whenever the channel count changes,
        # or when the caller explicitly asks for one.
        if conv_shortcut or in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(self, input_tensor, temb=None):
        """Apply the block.

        Args:
            input_tensor: 4-D feature map; the time embedding is broadcast
                over its last two dimensions.
            temb: Optional time embedding fed to ``time_emb_proj``.

        Returns:
            ``(shortcut(input_tensor) + residual) / output_scale_factor``.
        """
        hidden_states = self.norm1(input_tensor)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.nonlinearity(temb)
            # Two trailing singleton dims so the embedding broadcasts over
            # the spatial dimensions.
            temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
            if self.time_embedding_norm == "default":
                hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        return (input_tensor + hidden_states) / self.output_scale_factor
|
|
|
|
|
|
# ===== Transformer Blocks =====
|
|
|
|
class GEGLU(nn.Module):
    """Gated GELU unit: one projection split into a value half and a gate half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single linear layer produces both halves at once.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        """Return ``value * gelu(gate)`` with both halves taken from ``proj``."""
        projected = self.proj(hidden_states)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Feed-forward stack: GEGLU expansion (4x) -> dropout -> linear.

    Args:
        dim: Input width.
        dim_out: Output width; defaults to ``dim``.
        dropout: Dropout probability between the two projections.
    """

    def __init__(self, dim, dim_out=None, dropout=0.0):
        super().__init__()
        hidden_dim = dim * 4
        if dim_out is None:
            dim_out = dim
        layers = [
            GEGLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        ]
        self.net = nn.ModuleList(layers)

    def forward(self, hidden_states):
        """Run ``hidden_states`` through every layer of ``net`` in order."""
        output = hidden_states
        for layer in self.net:
            output = layer(output)
        return output
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head attention using ``F.scaled_dot_product_attention``.

    Acts as self-attention when no ``encoder_hidden_states`` is given and as
    cross-attention otherwise.

    Args:
        query_dim: Feature width of the query input (and of the output).
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability applied after the output projection.
        bias: Whether the q/k/v projections carry a bias.
        upcast_attention: Accepted for signature compatibility; not used.
        cross_attention_dim: Feature width of the key/value input; defaults
            to ``query_dim``.
    """

    def __init__(
        self,
        query_dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        bias=False,
        upcast_attention=False,
        cross_attention_dim=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.inner_dim = inner_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
        # Kept as an indexable list: [0] is the projection, [1] the dropout.
        self.to_out = nn.ModuleList([
            nn.Linear(inner_dim, query_dim, bias=True),
            nn.Dropout(dropout),
        ])

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: 3-D query input (unpacked as batch, tokens, dim).
            encoder_hidden_states: Key/value input; ``hidden_states`` is used
                when ``None``.
            attention_mask: Optional mask handed to scaled-dot-product
                attention. Boolean masks mark positions to keep, floating
                masks are added to the attention scores, and integer masks
                are converted to boolean. A 2-D ``(batch, keys)`` or 3-D
                ``(batch, queries_or_1, keys)`` mask is expanded so it
                broadcasts over the heads; a 4-D mask is passed through.

        Returns:
            Tensor with the same leading shape as ``hidden_states`` and last
            dimension ``query_dim``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = self.to_q(hidden_states)
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        batch_size, _, _ = query.shape
        head_dim = self.inner_dim // self.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # The mask used to be dropped here (attn_mask=None was hard-coded).
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, None, :, :]
            if attention_mask.is_floating_point():
                # SDPA needs additive masks in the query dtype.
                attention_mask = attention_mask.to(query.dtype)
            elif attention_mask.dtype != torch.bool:
                attention_mask = attention_mask.bool()

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back into a single feature dimension.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
|
|
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a
    residual connection.

    Args:
        dim: Token feature width.
        n_heads: Number of attention heads.
        d_head: Width of each attention head.
        dropout: Dropout probability passed to both attentions and the
            feed-forward.
        cross_attention_dim: Feature width of the conditioning sequence used
            by the second attention.
        upcast_attention: Passed through to both ``Attention`` layers.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        cross_attention_dim=None,
        upcast_attention=False,
    ):
        super().__init__()
        # Arguments shared by the self- and cross-attention layers.
        attention_kwargs = dict(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            bias=False,
            upcast_attention=upcast_attention,
        )

        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = Attention(**attention_kwargs)
        self.norm2 = nn.LayerNorm(dim)
        self.attn2 = Attention(cross_attention_dim=cross_attention_dim, **attention_kwargs)
        self.norm3 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, dropout=dropout)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply the three residual sub-layers in order.

        ``attention_mask`` is accepted but not passed to either attention
        layer.
        """
        # Self-attention.
        normed = self.norm1(hidden_states)
        hidden_states = self.attn1(normed) + hidden_states

        # Cross-attention over the conditioning sequence.
        normed = self.norm2(hidden_states)
        hidden_states = self.attn2(normed, encoder_hidden_states=encoder_hidden_states) + hidden_states

        # Position-wise feed-forward.
        normed = self.norm3(hidden_states)
        return self.ff(normed) + hidden_states
|
|
|
|
|
|
class Transformer2DModel(nn.Module):
    """Runs transformer blocks over a 2-D feature map.

    The map is group-normalised, projected to ``inner_dim`` channels,
    flattened to a token sequence, passed through the transformer blocks,
    projected back to ``in_channels`` and added to the input.

    Args:
        num_attention_heads: Number of attention heads.
        attention_head_dim: Width of each head; ``inner_dim`` is the product
            of the two.
        in_channels: Channels of the input (and output) feature map.
        num_layers: Number of ``BasicTransformerBlock`` instances.
        dropout: Dropout probability inside the blocks.
        norm_num_groups: Group count of the input GroupNorm.
        cross_attention_dim: Feature width of ``encoder_hidden_states``.
        upcast_attention: Passed through to the blocks.
        use_linear_projection: Use ``nn.Linear`` on tokens instead of 1x1
            ``nn.Conv2d`` on the map for the in/out projections.
    """

    def __init__(
        self,
        num_attention_heads=16,
        attention_head_dim=64,
        in_channels=320,
        num_layers=1,
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        upcast_attention=False,
        use_linear_projection=False,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.use_linear_projection = use_linear_projection

        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)

        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                dim=inner_dim,
                n_heads=num_attention_heads,
                d_head=attention_head_dim,
                dropout=dropout,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
            )
            for _ in range(num_layers)
        ])

        if use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Apply the transformer to a ``(batch, channels, height, width)`` map.

        ``attention_mask`` is accepted but not forwarded to the blocks.

        Returns:
            A tensor with the same shape as the input (residual included).
        """
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)

        if self.use_linear_projection:
            # Flatten first; the linear layer then maps channel -> inner_dim.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
            hidden_states = self.proj_in(hidden_states)
        else:
            hidden_states = self.proj_in(hidden_states)
            # After the conv the channel count is inner_dim, which may differ
            # from the input channel count; read it from the tensor rather
            # than reusing ``channel``.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
        else:
            # Tokens still carry inner_dim features until proj_out runs.
            inner_dim = hidden_states.shape[-1]
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)

        return hidden_states + residual
|
|
|
|
|
|
# ===== Down/Up Blocks =====
|
|
|
|
class CrossAttnDownBlock2D(nn.Module):
    """Encoder stage: ``num_layers`` x (ResNet block -> transformer), then an
    optional stride-2 downsampler.

    Note that ``attention_head_dim`` is passed to ``Transformer2DModel`` as
    the number of heads, and each head gets
    ``out_channels // attention_head_dim`` features.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        downsample=True,
        use_linear_projection=False,
    ):
        super().__init__()
        self.has_cross_attention = True

        # Settings identical for every ResNet block of this stage.
        resnet_kwargs = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=1.0,
            pre_norm=resnet_pre_norm,
        )
        # Settings identical for every transformer of this stage.
        transformer_kwargs = dict(
            num_attention_heads=attention_head_dim,
            attention_head_dim=out_channels // attention_head_dim,
            in_channels=out_channels,
            num_layers=transformer_layers_per_block,
            dropout=dropout,
            norm_num_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            use_linear_projection=use_linear_projection,
        )

        layer_pairs = []
        for layer_index in range(num_layers):
            # Only the first ResNet sees the stage's input channel count.
            first_layer = layer_index == 0
            resnet = ResnetBlock2D(
                in_channels=in_channels if first_layer else out_channels,
                **resnet_kwargs,
            )
            transformer = Transformer2DModel(**transformer_kwargs)
            layer_pairs.append((resnet, transformer))

        self.attentions = nn.ModuleList([transformer for _, transformer in layer_pairs])
        self.resnets = nn.ModuleList([resnet for resnet, _ in layer_pairs])

        if not downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList([
                Downsample2D(out_channels, out_channels, padding=1)
            ])

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, skip_outputs)``.

        ``skip_outputs`` is a tuple holding the output of every
        ResNet+transformer pair followed by the output of every downsampler.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
                collected.append(hidden_states)

        return hidden_states, tuple(collected)
|
|
|
|
|
|
class DownBlock2D(nn.Module):
    """Encoder stage without attention: ``num_layers`` ResNet blocks, then an
    optional stride-2 downsampler.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        downsample=True,
    ):
        super().__init__()
        self.has_cross_attention = False

        # Only the first ResNet sees the stage's input channel count.
        self.resnets = nn.ModuleList([
            ResnetBlock2D(
                in_channels=in_channels if layer_index == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
            for layer_index in range(num_layers)
        ])

        if not downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList([
                Downsample2D(out_channels, out_channels, padding=1)
            ])

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Return ``(hidden_states, skip_outputs)``.

        ``encoder_hidden_states`` is accepted so the call signature matches
        the cross-attention stage, but it is not used here.
        """
        collected = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
                collected.append(hidden_states)

        return hidden_states, tuple(collected)
|
|
|
|
|
|
class CrossAttnUpBlock2D(nn.Module):
    """Decoder stage: ``num_layers`` x (skip concat -> ResNet -> transformer),
    then an optional upsampler.

    Each ResNet's input width is the running feature width plus the width of
    the skip connection it concatenates.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        upsample=True,
        use_linear_projection=False,
    ):
        super().__init__()
        self.has_cross_attention = True

        last_layer = num_layers - 1
        layer_pairs = []
        for layer_index in range(num_layers):
            # The last layer consumes a skip of width ``in_channels``; the
            # others consume skips of width ``out_channels``.
            skip_width = in_channels if layer_index == last_layer else out_channels
            # The first layer continues from the previous stage's output.
            running_width = prev_output_channel if layer_index == 0 else out_channels

            resnet = ResnetBlock2D(
                in_channels=running_width + skip_width,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )
            transformer = Transformer2DModel(
                num_attention_heads=attention_head_dim,
                attention_head_dim=out_channels // attention_head_dim,
                in_channels=out_channels,
                num_layers=transformer_layers_per_block,
                dropout=dropout,
                norm_num_groups=resnet_groups,
                cross_attention_dim=cross_attention_dim,
                use_linear_projection=use_linear_projection,
            )
            layer_pairs.append((resnet, transformer))

        self.attentions = nn.ModuleList([transformer for _, transformer in layer_pairs])
        self.resnets = nn.ModuleList([resnet for resnet, _ in layer_pairs])

        if not upsample:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList([
                Upsample2D(out_channels, out_channels)
            ])

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        """Consume skip tensors from the end of ``res_hidden_states_tuple``.

        One skip tensor is taken (last first) per ResNet and concatenated on
        the channel dimension before the ResNet runs.
        """
        layers = zip(self.resnets, self.attentions)
        for offset, (resnet, attn) in enumerate(layers, start=1):
            # offset=1 picks the last skip, offset=2 the one before, ...
            skip = res_hidden_states_tuple[-offset]
            hidden_states = torch.cat([hidden_states, skip], dim=1)
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
|
|
|
|
|
class UpBlock2D(nn.Module):
    """Decoder stage without attention: ``num_layers`` x (skip concat ->
    ResNet), then an optional upsampler.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        upsample=True,
    ):
        super().__init__()
        self.has_cross_attention = False

        last_layer = num_layers - 1
        blocks = []
        for layer_index in range(num_layers):
            # The last layer consumes a skip of width ``in_channels``; the
            # others consume skips of width ``out_channels``.
            skip_width = in_channels if layer_index == last_layer else out_channels
            # The first layer continues from the previous stage's output.
            running_width = prev_output_channel if layer_index == 0 else out_channels
            blocks.append(
                ResnetBlock2D(
                    in_channels=running_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=1.0,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(blocks)

        if not upsample:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList([
                Upsample2D(out_channels, out_channels)
            ])

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
        """Consume skip tensors from the end of ``res_hidden_states_tuple``.

        ``encoder_hidden_states`` is accepted so the call signature matches
        the cross-attention stage, but it is not used here.
        """
        for offset, resnet in enumerate(self.resnets, start=1):
            # offset=1 picks the last skip, offset=2 the one before, ...
            skip = res_hidden_states_tuple[-offset]
            hidden_states = torch.cat([hidden_states, skip], dim=1)
            hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size=upsample_size)

        return hidden_states
|
|
|
|
|
|
# ===== UNet Mid Block =====
|
|
|
|
class UNetMidBlock2DCrossAttn(nn.Module):
    """Bottleneck: one ResNet block, then ``num_layers`` x (transformer ->
    ResNet block), all at ``in_channels`` width.
    """

    def __init__(
        self,
        in_channels,
        temb_channels=1280,
        dropout=0.0,
        num_layers=1,
        transformer_layers_per_block=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        resnet_pre_norm=True,
        cross_attention_dim=768,
        attention_head_dim=1,
        use_linear_projection=False,
    ):
        super().__init__()
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            # Every ResNet of the mid block shares the same configuration.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=1.0,
                pre_norm=resnet_pre_norm,
            )

        def build_transformer():
            # ``attention_head_dim`` is used as the head count here.
            return Transformer2DModel(
                num_attention_heads=attention_head_dim,
                attention_head_dim=in_channels // attention_head_dim,
                in_channels=in_channels,
                num_layers=transformer_layers_per_block,
                dropout=dropout,
                norm_num_groups=resnet_groups,
                cross_attention_dim=cross_attention_dim,
                use_linear_projection=use_linear_projection,
            )

        resnet_list = [build_resnet()]
        attention_list = []
        for _ in range(num_layers):
            attention_list.append(build_transformer())
            resnet_list.append(build_resnet())

        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Run the leading ResNet, then each (transformer, ResNet) pair."""
        hidden_states = self.resnets[0](hidden_states, temb)
        trailing_resnets = self.resnets[1:]
        for attn, resnet in zip(self.attentions, trailing_resnets):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)
        return hidden_states
|
|
|
|
|
|
# ===== Downsample / Upsample =====
|
|
|
|
class Downsample2D(nn.Module):
    """Stride-2 3x3 convolution.

    With ``padding=0`` the input is zero-padded by one pixel on the right
    and bottom before the convolution.
    """

    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=padding,
        )
        self.padding = padding

    def forward(self, hidden_states):
        """Downsample ``hidden_states`` with the strided convolution."""
        needs_manual_pad = self.padding == 0
        if needs_manual_pad:
            # F.pad order is (left, right, top, bottom).
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(hidden_states)
|
|
|
|
|
|
class Upsample2D(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, upsample_size=None):
        """Upsample to ``upsample_size`` when given, otherwise by a factor of 2."""
        if upsample_size is None:
            resized = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            resized = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
        return self.conv(resized)
|
|
|
|
|
|
# ===== SDXL UNet2DConditionModel =====
|
|
|
|
class SDXLUNet2DConditionModel(nn.Module):
    """Conditional UNet assembled from the down / mid / up blocks of this file.

    Args:
        sample_size: Stored on the instance; not otherwise used here.
        in_channels: Channels of the input sample.
        out_channels: Channels of the prediction.
        down_block_types: One name per encoder stage; names containing
            ``"CrossAttn"`` build ``CrossAttnDownBlock2D``, all others
            ``DownBlock2D``.
        up_block_types: Same convention for the decoder stages.
        block_out_channels: Output width of each encoder stage.
        layers_per_block: ResNet layers per encoder stage (decoder stages use
            one more).
        cross_attention_dim: Feature width of ``encoder_hidden_states``.
        attention_head_dim: Int or per-stage sequence, forwarded to the
            blocks' ``attention_head_dim`` argument.
        transformer_layers_per_block: Int or per-stage sequence.
        norm_num_groups: GroupNorm group count.
        norm_eps: GroupNorm epsilon.
        dropout: Dropout probability.
        act_fn: Activation name passed to the ResNet blocks.
        time_embedding_type: Accepted for config compatibility; unused.
        flip_sin_to_cos: Passed to ``Timesteps``.
        freq_shift: Passed to ``Timesteps``.
        time_embedding_dim: Width of the sinusoidal timestep projection;
            defaults to ``block_out_channels[0]``.
        resnet_time_scale_shift: Passed to the ResNet blocks.
        upcast_attention: Accepted for config compatibility; unused.
        use_linear_projection: Passed to the transformer blocks.
        addition_embed_type: ``"text_time"`` enables the extra text/time
            conditioning branch.
        addition_time_embed_dim: Width of the sinusoidal projection applied
            to each entry of ``time_ids``.
        projection_class_embeddings_input_dim: Input width of the additional
            embedding MLP.
    """

    def __init__(
        self,
        sample_size=128,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
        block_out_channels=(320, 640, 1280),
        layers_per_block=2,
        cross_attention_dim=2048,
        attention_head_dim=5,
        transformer_layers_per_block=1,
        norm_num_groups=32,
        norm_eps=1e-5,
        dropout=0.0,
        act_fn="silu",
        time_embedding_type="positional",
        flip_sin_to_cos=True,
        freq_shift=0,
        time_embedding_dim=None,
        resnet_time_scale_shift="default",
        upcast_attention=False,
        use_linear_projection=False,
        addition_embed_type=None,
        addition_time_embed_dim=None,
        projection_class_embeddings_input_dim=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sample_size = sample_size
        self.addition_embed_type = addition_embed_type

        # Broadcast scalar settings to one entry per encoder stage.
        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        # --- time / additional embeddings ---
        timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
        self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
        time_embed_dim = block_out_channels[0] * 4
        self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)

        if addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        num_stages = len(block_out_channels)

        # --- encoder ---
        self.down_blocks = nn.ModuleList()
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The deepest stage keeps its resolution.
            is_final_block = i == num_stages - 1

            if "CrossAttn" in down_block_type:
                down_block = CrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    transformer_layers_per_block=transformer_layers_per_block[i],
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_head_dim=attention_head_dim[i],
                    downsample=not is_final_block,
                    use_linear_projection=use_linear_projection,
                )
            else:
                down_block = DownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    downsample=not is_final_block,
                )
            self.down_blocks.append(down_block)

        # --- bottleneck ---
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            num_layers=1,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            resnet_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim[-1],
            use_linear_projection=use_linear_projection,
        )

        # --- decoder (mirrors the encoder, deepest stage first) ---
        self.up_blocks = nn.ModuleList()
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        output_channel = reversed_block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # The last decoder stage is already at full resolution.
            is_final_block = i == num_stages - 1
            # Width of the shallowest skip connection this stage consumes.
            input_channel = reversed_block_out_channels[min(i + 1, num_stages - 1)]

            if "CrossAttn" in up_block_type:
                up_block = CrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    attention_head_dim=reversed_attention_head_dim[i],
                    upsample=not is_final_block,
                    use_linear_projection=use_linear_projection,
                )
            else:
                up_block = UpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    temb_channels=time_embed_dim,
                    dropout=dropout,
                    num_layers=layers_per_block + 1,
                    resnet_eps=norm_eps,
                    resnet_time_scale_shift=resnet_time_scale_shift,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    upsample=not is_final_block,
                )
            self.up_blocks.append(up_block)

        # --- output head ---
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
        """Run the UNet.

        Args:
            sample: Input feature map fed to ``conv_in``.
            timestep: Python number, 0-d tensor or 1-d tensor of timesteps.
            encoder_hidden_states: Conditioning sequence for cross-attention.
            cross_attention_kwargs: Accepted for compatibility; unused.
            timestep_cond: Accepted for compatibility; unused.
            added_cond_kwargs: Mapping with ``"text_embeds"`` and
                ``"time_ids"``; required when ``addition_embed_type`` is
                ``"text_time"``.
            return_dict: When false the result is wrapped in a 1-tuple;
                otherwise the tensor itself is returned.

        Raises:
            ValueError: If ``addition_embed_type`` is ``"text_time"`` and
                ``text_embeds`` or ``time_ids`` is missing.
        """
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # Keep fractional timesteps: forcing ``torch.long`` would
            # silently truncate a Python float.
            dtype = torch.float32 if isinstance(timesteps, float) else torch.long
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # The sinusoidal projection is float32; match the sample dtype before
        # it enters the learned MLP.
        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb)

        if self.addition_embed_type == "text_time":
            conditions = added_cond_kwargs if added_cond_kwargs is not None else {}
            text_embeds = conditions.get("text_embeds")
            time_ids = conditions.get("time_ids")
            if text_embeds is None or time_ids is None:
                raise ValueError(
                    "addition_embed_type='text_time' requires `added_cond_kwargs` "
                    "with both 'text_embeds' and 'time_ids'"
                )
            # Embed every id separately, then regroup them per sample.
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
            emb = emb + aug_emb

        sample = self.conv_in(sample)

        # Encoder: collect every intermediate output as a skip connection.
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            sample, res_samples = down_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
            )
            down_block_res_samples += res_samples

        sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

        # Decoder: each stage takes as many skips as it has ResNet layers.
        for up_block in self.up_blocks:
            num_skips = len(up_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            # Upsample to the spatial size of the next skip that will be
            # concatenated (None once all skips are consumed).
            upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
            sample = up_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)
        return sample
|