mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
sd
This commit is contained in:
642
diffsynth/models/stable_diffusion_vae.py
Normal file
642
diffsynth/models/stable_diffusion_vae.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution:
|
||||
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
||||
)
|
||||
|
||||
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
|
||||
# randn_like doesn't accept generator on all torch versions
|
||||
sample = torch.randn(self.mean.shape, generator=generator,
|
||||
device=self.parameters.device, dtype=self.parameters.dtype)
|
||||
return self.mean + self.std * sample
|
||||
|
||||
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.tensor([0.0])
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
return self.mean
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif non_linearity == "gelu":
|
||||
self.nonlinearity = nn.GELU()
|
||||
elif non_linearity == "relu":
|
||||
self.nonlinearity = nn.ReLU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
|
||||
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.conv_shortcut = None
|
||||
if conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
|
||||
|
||||
def forward(self, input_tensor, temb=None):
|
||||
hidden_states = input_tensor
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
return output_tensor
|
||||
|
||||
|
||||
class DownEncoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([
|
||||
Downsample2D(out_channels, out_channels, padding=downsample_padding)
|
||||
])
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
def forward(self, hidden_states, *args, **kwargs):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None)
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
temb_channels=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
in_channels_i = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels_i,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
Upsample2D(out_channels, out_channels)
|
||||
])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetMidBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
temb_channels=None,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_act_fn="swish",
|
||||
resnet_groups=32,
|
||||
resnet_pre_norm=True,
|
||||
add_attention=True,
|
||||
attention_head_dim=1,
|
||||
output_scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_groups=resnet_groups,
|
||||
eps=resnet_eps,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Simple attention block for VAE mid block.
|
||||
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
|
||||
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
|
||||
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
|
||||
"""
|
||||
def __init__(self, channels, num_groups=32, eps=1e-6):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
self.heads = 1
|
||||
self.rescale_output_factor = 1.0
|
||||
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(channels, channels, bias=True)
|
||||
self.to_k = nn.Linear(channels, channels, bias=True)
|
||||
self.to_v = nn.Linear(channels, channels, bias=True)
|
||||
self.to_out = nn.ModuleList([
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
nn.Dropout(0.0),
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
||||
# Group norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# QKV projection
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
|
||||
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // self.heads
|
||||
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Scaled dot-product attention
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Output projection + dropout
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
# Reshape back to 4D and add residual
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# Rescale output factor
|
||||
hidden_states = hidden_states / self.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
|
||||
Key names: conv.weight/bias.
|
||||
When padding=0, applies explicit F.pad before conv to match dimension.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, padding=1):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.padding = padding
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.padding == 0:
|
||||
import torch.nn.functional as F
|
||||
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
return self.conv(hidden_states)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
double_z=True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
down_block = DownEncoderBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_downsample=not is_final_block,
|
||||
downsample_padding=0,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.conv_in(sample)
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample)
|
||||
sample = self.mid_block(sample)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
norm_type="group",
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
temb_channels = in_channels if norm_type == "spatial" else None
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=temb_channels,
|
||||
add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
up_block = UpDecoderBlock2D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_upsample=not is_final_block,
|
||||
temb_channels=temb_channels,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, sample, latent_embeds=None):
|
||||
sample = self.conv_in(sample)
|
||||
sample = self.mid_block(sample, latent_embeds)
|
||||
for up_block in self.up_blocks:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class StableDiffusionVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
|
||||
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
|
||||
block_out_channels=(128, 256, 512, 512),
|
||||
layers_per_block=2,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=32,
|
||||
sample_size=512,
|
||||
scaling_factor=0.18215,
|
||||
shift_factor=None,
|
||||
latents_mean=None,
|
||||
latents_std=None,
|
||||
force_upcast=True,
|
||||
use_quant_conv=True,
|
||||
use_post_quant_conv=True,
|
||||
mid_block_add_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
double_z=True,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
|
||||
self.latents_mean = latents_mean
|
||||
self.latents_std = latents_std
|
||||
self.scaling_factor = scaling_factor
|
||||
self.shift_factor = shift_factor
|
||||
self.sample_size = sample_size
|
||||
self.force_upcast = force_upcast
|
||||
|
||||
def _encode(self, x):
|
||||
h = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def encode(self, x):
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
return posterior
|
||||
|
||||
def _decode(self, z):
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
return self.decoder(z)
|
||||
|
||||
def decode(self, z):
|
||||
return self._decode(z)
|
||||
|
||||
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
|
||||
posterior = self.encode(sample)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
# Scale latent
|
||||
z = z * self.scaling_factor
|
||||
decode = self.decode(z)
|
||||
if return_dict:
|
||||
return {"sample": decode, "posterior": posterior, "latent_sample": z}
|
||||
return decode, posterior
|
||||
Reference in New Issue
Block a user