# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Dict, Optional, Tuple, Union, Callable import torch import torch.nn as nn from einops import rearrange import torch.nn.functional as F import inspect ACT2CLS = { "swish": nn.SiLU, "silu": nn.SiLU, "mish": nn.Mish, "gelu": nn.GELU, "relu": nn.ReLU, } def get_activation(act_fn: str) -> nn.Module: """Helper function to get activation function from string. Args: act_fn (str): Name of activation function. Returns: nn.Module: Activation function. """ act_fn = act_fn.lower() if act_fn in ACT2CLS: return ACT2CLS[act_fn]() else: raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") class ResnetBlock2D(nn.Module): r""" A Resnet block. Parameters: in_channels (`int`): The number of channels in the input. out_channels (`int`, *optional*, default to be `None`): The number of output channels for the first conv2d layer. If None, same as `in_channels`. dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. groups_out (`int`, *optional*, default to None): The number of groups to use for the second normalization layer. if set to None, same as `groups`. eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a stronger conditioning with scale and shift. kernel (`torch.Tensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. use_in_shortcut (`bool`, *optional*, default to `True`): If `True`, add a 1x1 nn.conv2d layer for skip-connection. up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the `conv_shortcut` output. conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. If None, same as `out_channels`. """ def __init__( self, *, in_channels: int, out_channels: Optional[int] = None, conv_shortcut: bool = False, dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, pre_norm: bool = True, eps: float = 1e-6, non_linearity: str = "swish", skip_time_act: bool = False, time_embedding_norm: str = "default", # default, scale_shift, kernel: Optional[torch.Tensor] = None, output_scale_factor: float = 1.0, use_in_shortcut: Optional[bool] = None, up: bool = False, down: bool = False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): super().__init__() if time_embedding_norm == "ada_group": raise ValueError( "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead", ) if time_embedding_norm == "spatial": raise ValueError( "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead", ) self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.up = up self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act if groups_out is None: groups_out = groups self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": self.time_emb_proj = nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels) else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: self.time_emb_proj = None self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") else: self.upsample = Upsample2D(in_channels, use_conv=False) elif self.down: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias, ) def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) hidden_states = input_tensor hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] if self.time_embedding_norm == "default": if temb is not None: hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) elif self.time_embedding_norm == "scale_shift": if temb is None: raise ValueError( f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}" ) time_scale, time_shift = torch.chunk(temb, 2, dim=1) hidden_states = self.norm2(hidden_states) hidden_states = hidden_states * (1 + time_scale) + time_shift else: hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor.contiguous()) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor class Downsample2D(nn.Module): """A 2D downsampling layer with an optional convolution. Parameters: channels (`int`): number of channels in the inputs and outputs. use_conv (`bool`, default `False`): option to use a convolution. out_channels (`int`, optional): number of output channels. Defaults to `channels`. padding (`int`, default `1`): padding for the convolution. name (`str`, default `conv`): name of the downsampling 2D layer. """ def __init__( self, channels: int, use_conv: bool = False, out_channels: Optional[int] = None, padding: int = 1, name: str = "conv", kernel_size=3, norm_type=None, eps=None, elementwise_affine=None, bias=True, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.padding = padding stride = 2 self.name = name if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) elif norm_type == "rms_norm": self.norm = RMSNorm(channels, eps, elementwise_affine) elif norm_type is None: self.norm = None else: raise ValueError(f"unknown norm_type: {norm_type}") if use_conv: conv = nn.Conv2d( self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.Conv2d_0 = conv self.conv = conv elif name == "Conv2d_0": self.conv = conv else: self.conv = conv def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels if self.norm is not None: hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels hidden_states = self.conv(hidden_states) return hidden_states class Upsample2D(nn.Module): """A 2D upsampling layer with an optional convolution. Parameters: channels (`int`): number of channels in the inputs and outputs. use_conv (`bool`, default `False`): option to use a convolution. use_conv_transpose (`bool`, default `False`): option to use a convolution transpose. out_channels (`int`, optional): number of output channels. Defaults to `channels`. name (`str`, default `conv`): name of the upsampling 2D layer. """ def __init__( self, channels: int, use_conv: bool = False, use_conv_transpose: bool = False, out_channels: Optional[int] = None, name: str = "conv", kernel_size: Optional[int] = None, padding=1, norm_type=None, eps=None, elementwise_affine=None, bias=True, interpolate=True, ): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name self.interpolate = interpolate if norm_type == "ln_norm": self.norm = nn.LayerNorm(channels, eps, elementwise_affine) elif norm_type == "rms_norm": self.norm = RMSNorm(channels, eps, elementwise_affine) elif norm_type is None: self.norm = None else: raise ValueError(f"unknown norm_type: {norm_type}") conv = None if use_conv_transpose: if kernel_size is None: kernel_size = 4 conv = nn.ConvTranspose2d( channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias ) elif use_conv: if kernel_size is None: kernel_size = 3 conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.conv = conv else: self.Conv2d_0 = conv def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) assert hidden_states.shape[1] == self.channels if self.norm is not None: hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) if self.use_conv_transpose: return self.conv(hidden_states) # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: hidden_states = hidden_states.contiguous() # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if self.interpolate: # upsample_nearest_nhwc also fails when the number of output elements is large # https://github.com/pytorch/pytorch/issues/141831 scale_factor = ( 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) ) if hidden_states.numel() * scale_factor > pow(2, 31): hidden_states = hidden_states.contiguous() if output_size is None: hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # Cast back to original dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(dtype) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": hidden_states = self.conv(hidden_states) else: hidden_states = self.Conv2d_0(hidden_states) return hidden_states class Attention(nn.Module): r""" A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. kv_heads (`int`, *optional*, defaults to `None`): The number of key and value heads to use for multi-head attention. Defaults to `heads`. If `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. upcast_attention (`bool`, *optional*, defaults to False): Set to `True` to upcast the attention computation to `float32`. upcast_softmax (`bool`, *optional*, defaults to False): Set to `True` to upcast the softmax computation to `float32`. cross_attention_norm (`str`, *optional*, defaults to `None`): The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the group norm in the cross attention. added_kv_proj_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the added key and value projections. If `None`, no projection is used. norm_num_groups (`int`, *optional*, defaults to `None`): The number of groups to use for the group norm in the attention. spatial_norm_dim (`int`, *optional*, defaults to `None`): The number of channels to use for the spatial normalization. out_bias (`bool`, *optional*, defaults to `True`): Set to `True` to use a bias in the output linear layer. scale_qk (`bool`, *optional*, defaults to `True`): Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. only_cross_attention (`bool`, *optional*, defaults to `False`): Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if `added_kv_proj_dim` is not `None`. eps (`float`, *optional*, defaults to 1e-5): An additional value added to the denominator in group normalization that is used for numerical stability. rescale_output_factor (`float`, *optional*, defaults to 1.0): A factor to rescale the output by dividing it with this value. residual_connection (`bool`, *optional*, defaults to `False`): Set to `True` to add the residual connection to the output. _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): Set to `True` if the attention block is loaded from a deprecated state dict. processor (`AttnProcessor`, *optional*, defaults to `None`): The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and `AttnProcessor` otherwise. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, kv_heads: Optional[int] = None, dim_head: int = 64, dropout: float = 0.0, bias: bool = False, upcast_attention: bool = False, upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, added_proj_bias: Optional[bool] = True, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, out_context_dim: int = None, context_pre_only=None, pre_only=False, elementwise_affine: bool = True, is_causal: bool = False, ): super().__init__() # To prevent circular import. # from .normalization import FP32LayerNorm, LpNorm, RMSNorm self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.query_dim = query_dim self.use_bias = bias self.is_cross_attention = cross_attention_dim is not None self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout self.fused_projections = False self.out_dim = out_dim if out_dim is not None else query_dim self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim self.context_pre_only = context_pre_only self.pre_only = pre_only self.is_causal = is_causal # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = out_dim // dim_head if out_dim is not None else heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." ) if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None if spatial_norm_dim is not None: self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None if qk_norm is None: self.norm_q = None self.norm_k = None elif qk_norm == "layer_norm": self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) elif qk_norm == "fp32_layer_norm": self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "layer_norm_across_heads": # Lumina applies qk norm across all heads self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) elif qk_norm == "rms_norm_across_heads": # LTX applies qk norm across all heads self.norm_q = RMSNorm(dim_head * heads, eps=eps) self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "l2": self.norm_q = LpNorm(p=2, dim=-1, eps=eps) self.norm_k = LpNorm(p=2, dim=-1, eps=eps) else: raise ValueError( f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'." ) if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(self.cross_attention_dim) elif cross_attention_norm == "group_norm": if self.added_kv_proj_dim is not None: # The given `encoder_hidden_states` are initially of shape # (batch_size, seq_len, added_kv_proj_dim) before being projected # to (batch_size, seq_len, cross_attention_dim). The norm is applied # before the projection, so we need to use `added_kv_proj_dim` as # the number of channels for the group norm. norm_cross_num_channels = added_kv_proj_dim else: norm_cross_num_channels = self.cross_attention_dim self.norm_cross = nn.GroupNorm( num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True ) else: raise ValueError( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) else: self.to_k = None self.to_v = None self.added_proj_bias = added_proj_bias if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) if self.context_pre_only is not None: self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) else: self.add_q_proj = None self.add_k_proj = None self.add_v_proj = None if not self.pre_only: self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) else: self.to_out = None if self.context_pre_only is not None and not self.context_pre_only: self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) else: self.to_add_out = None if qk_norm is not None and added_kv_proj_dim is not None: if qk_norm == "layer_norm": self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) elif qk_norm == "fp32_layer_norm": self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "rms_norm": self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) elif qk_norm == "rms_norm_across_heads": # Wan applies qk norm across all heads # Wan also doesn't apply a q norm self.norm_added_q = None self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps) else: raise ValueError( f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" ) else: self.norm_added_q = None self.norm_added_k = None # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_xla_flash_attention( self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, is_flux=False, ) -> None: r""" Set whether to use xla flash attention from `torch_xla` or not. Args: use_xla_flash_attention (`bool`): Whether to use pallas flash attention kernel from `torch_xla` or not. partition_spec (`Tuple[]`, *optional*): Specify the partition specification if using SPMD. Otherwise None. """ if use_xla_flash_attention: if not is_torch_xla_available: raise "torch_xla is not available" elif is_torch_xla_version("<", "2.3"): raise "flash attention pallas kernel is supported from torch_xla version 2.3" elif is_spmd() and is_torch_xla_version("<", "2.4"): raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4" else: if is_flux: processor = XLAFluxFlashAttnProcessor2_0(partition_spec) else: processor = XLAFlashAttnProcessor2_0(partition_spec) else: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: r""" Set whether to use npu flash attention from `torch_npu` or not. """ if use_npu_flash_attention: processor = AttnProcessorNPU() else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ) -> None: r""" Set whether to use memory efficient attention from `xformers` or not. Args: use_memory_efficient_attention_xformers (`bool`): Whether to use memory efficient attention from `xformers` or not. attention_op (`Callable`, *optional*): The attention operation to use. Defaults to `None` which uses the default attention operation from `xformers`. """ is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0), ) is_added_kv_processor = hasattr(self, "processor") and isinstance( self.processor, ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, XFormersAttnAddedKVProcessor, ), ) is_ip_adapter = hasattr(self, "processor") and isinstance( self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor), ) is_joint_processor = hasattr(self, "processor") and isinstance( self.processor, ( JointAttnProcessor2_0, XFormersJointAttnProcessor, ), ) if use_memory_efficient_attention_xformers: if is_added_kv_processor and is_custom_diffusion: raise NotImplementedError( f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}" ) if not is_xformers_available(): raise ModuleNotFoundError( ( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers" ), name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" " only available for GPU " ) else: try: # Make sure we can run the memory efficient attention dtype = None if attention_op is not None: op_fw, op_bw = attention_op dtype, *_ = op_fw.SUPPORTED_DTYPES q = torch.randn((1, 2, 40), device="cuda", dtype=dtype) _ = xformers.ops.memory_efficient_attention(q, q, q) except Exception as e: raise e if is_custom_diffusion: processor = CustomDiffusionXFormersAttnProcessor( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, attention_op=attention_op, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) elif is_added_kv_processor: # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP # which uses this type of cross attention ONLY because the attention mask of format # [0, ..., -10.000, ..., 0, ...,] is not supported # throw warning logger.info( "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." ) processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) elif is_ip_adapter: processor = IPAdapterXFormersAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, num_tokens=self.processor.num_tokens, scale=self.processor.scale, attention_op=attention_op, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_ip"): processor.to( device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype ) elif is_joint_processor: processor = XFormersJointAttnProcessor(attention_op=attention_op) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_custom_diffusion: attn_processor_class = ( CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor ) processor = attn_processor_class( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) elif is_ip_adapter: processor = IPAdapterAttnProcessor2_0( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, num_tokens=self.processor.num_tokens, scale=self.processor.scale, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_ip"): processor.to( device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype ) else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_attention_slice(self, slice_size: int) -> None: r""" Set the slice size for attention computation. Args: slice_size (`int`): The slice size for attention computation. """ if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") if slice_size is not None and self.added_kv_proj_dim is not None: processor = SlicedAttnAddedKVProcessor(slice_size) elif slice_size is not None: processor = SlicedAttnProcessor(slice_size) elif self.added_kv_proj_dim is not None: processor = AttnAddedKVProcessor() else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_processor(self, processor: "AttnProcessor") -> None: r""" Set the attention processor to use. Args: processor (`AttnProcessor`): The attention processor to use. """ # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module) and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") self.processor = processor def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": r""" Get the attention processor in use. Args: return_deprecated_lora (`bool`, *optional*, defaults to `False`): Set to `True` to return the deprecated LoRA attention processor. Returns: "AttentionProcessor": The attention processor in use. """ if not return_deprecated_lora: return self.processor def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **cross_attention_kwargs, ) -> torch.Tensor: r""" The forward method of the `Attention` class. Args: hidden_states (`torch.Tensor`): The hidden states of the query. encoder_hidden_states (`torch.Tensor`, *optional*): The hidden states of the encoder. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. **cross_attention_kwargs: Additional keyword arguments to pass along to the cross attention. Returns: `torch.Tensor`: The output of the attention layer. """ # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} unused_kwargs = [ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters ] if len(unused_kwargs) > 0: logger.warning( f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." ) cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} return self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: r""" Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` is the number of heads initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape. Returns: `torch.Tensor`: The reshaped tensor. """ head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: r""" Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is the number of heads initialized while constructing the `Attention` class. Args: tensor (`torch.Tensor`): The tensor to reshape. out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is reshaped to `[batch_size * heads, seq_len, dim // heads]`. Returns: `torch.Tensor`: The reshaped tensor. """ head_size = self.heads if tensor.ndim == 3: batch_size, seq_len, dim = tensor.shape extra_dim = 1 else: batch_size, extra_dim, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3) if out_dim == 3: tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) return tensor def get_attention_scores( self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Compute the attention scores. Args: query (`torch.Tensor`): The query tensor. key (`torch.Tensor`): The key tensor. attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. Returns: `torch.Tensor`: The attention probabilities/scores. """ dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, ) del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) del attention_scores attention_probs = attention_probs.to(dtype) return attention_probs def prepare_attention_mask( self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 ) -> torch.Tensor: r""" Prepare the attention mask for the attention computation. Args: attention_mask (`torch.Tensor`): The attention mask to prepare. target_length (`int`): The target length of the attention mask. This is the length of the attention mask after padding. batch_size (`int`): The batch size, which is used to repeat the attention mask. out_dim (`int`, *optional*, defaults to `3`): The output dimension of the attention mask. Can be either `3` or `4`. Returns: `torch.Tensor`: The prepared attention mask. """ head_size = self.heads if attention_mask is None: return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave( head_size, dim=0, output_size=attention_mask.shape[0] * head_size ) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) attention_mask = attention_mask.repeat_interleave( head_size, dim=1, output_size=attention_mask.shape[1] * head_size ) return attention_mask def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: r""" Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the `Attention` class. Args: encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. Returns: `torch.Tensor`: The normalized encoder hidden states. """ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" if isinstance(self.norm_cross, nn.LayerNorm): encoder_hidden_states = self.norm_cross(encoder_hidden_states) elif isinstance(self.norm_cross, nn.GroupNorm): # Group norm norms along the channels dimension and expects # input to be in the shape of (N, C, *). In this case, we want # to norm along the hidden dimension, so we need to move # (batch_size, sequence_length, hidden_size) -> # (batch_size, hidden_size, sequence_length) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) encoder_hidden_states = self.norm_cross(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) else: assert False return encoder_hidden_states @torch.no_grad() def fuse_projections(self, fuse=True): device = self.to_q.weight.data.device dtype = self.to_q.weight.data.dtype if not self.is_cross_attention: # fetch weight matrices. concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] # create a new single projection layer and copy over the weights. self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) self.to_qkv.weight.copy_(concatenated_weights) if self.use_bias: concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) self.to_qkv.bias.copy_(concatenated_bias) else: concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) self.to_kv.weight.copy_(concatenated_weights) if self.use_bias: concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) self.to_kv.bias.copy_(concatenated_bias) # handle added projections for SD3 and others. if ( getattr(self, "add_q_proj", None) is not None and getattr(self, "add_k_proj", None) is not None and getattr(self, "add_v_proj", None) is not None ): concatenated_weights = torch.cat( [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] ) in_features = concatenated_weights.shape[1] out_features = concatenated_weights.shape[0] self.to_added_qkv = nn.Linear( in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype ) self.to_added_qkv.weight.copy_(concatenated_weights) if self.added_proj_bias: concatenated_bias = torch.cat( [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] ) self.to_added_qkv.bias.copy_(concatenated_bias) self.fused_projections = fuse class AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, *args, **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class UNetMidBlock2D(nn.Module): """ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. Args: in_channels (`int`): The number of input channels. temb_channels (`int`): The number of temporal embedding channels. dropout (`float`, *optional*, defaults to 0.0): The dropout rate. num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. resnet_time_scale_shift (`str`, *optional*, defaults to `default`): The type of normalization to apply to the time embeddings. This can help to improve the performance of the model on tasks with long-range temporal dependencies. resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. resnet_groups (`int`, *optional*, defaults to 32): The number of groups to use in the group normalization layers of the resnet blocks. attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. resnet_pre_norm (`bool`, *optional*, defaults to `True`): Whether to use pre-normalization for the resnet blocks. add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels. output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. Returns: `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels, height, width)`. """ def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, attn_groups: Optional[int] = None, resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention if attn_groups is None: attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None # there is always at least one resnet if resnet_time_scale_shift == "spatial": resnets = [ ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ] else: resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] if attention_head_dim is None: logger.warning( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." ) attention_head_dim = in_channels for _ in range(num_layers): if self.add_attention: attentions.append( Attention( in_channels, heads=in_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=attn_groups, spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) else: attentions.append(None) if resnet_time_scale_shift == "spatial": resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ) else: resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if torch.is_grad_enabled() and self.gradient_checkpointing: if attn is not None: hidden_states = attn(hidden_states, temb=temb) hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) else: if attn is not None: hidden_states = attn(hidden_states, temb=temb) hidden_states = resnet(hidden_states, temb) return hidden_states class DownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels if resnet_time_scale_shift == "spatial": resnets.append( ResnetBlockCondNorm2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ) else: resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class UpDecoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_upsample: bool = True, temb_channels: Optional[int] = None, ): super().__init__() resnets = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels if resnet_time_scale_shift == "spatial": resnets.append( ResnetBlockCondNorm2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm="spatial", non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, ) ) else: resnets.append( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.resolution_idx = resolution_idx def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. Args: in_channels (`int`, *optional*, defaults to 3): The number of input channels. out_channels (`int`, *optional*, defaults to 3): The number of output channels. down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. double_z (`bool`, *optional*, defaults to `True`): Whether to double the number of output channels for the last block. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, ) self.down_blocks = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = DownEncoderBlock2D( num_layers=self.layers_per_block, in_channels=input_channel, out_channels=output_channel, add_downsample=not is_final_block, resnet_eps=1e-6, downsample_padding=0, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, # attention_head_dim=output_channel, # temb_channels=None, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, add_attention=mid_block_add_attention, ) # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) self.gradient_checkpointing = False def forward(self, sample: torch.Tensor) -> torch.Tensor: r"""The forward method of the `Encoder` class.""" sample = self.conv_in(sample) if torch.is_grad_enabled() and self.gradient_checkpointing: # down for down_block in self.down_blocks: sample = self._gradient_checkpointing_func(down_block, sample) # middle sample = self._gradient_checkpointing_func(self.mid_block, sample) else: # down for down_block in self.down_blocks: sample = down_block(sample) # middle sample = self.mid_block(sample) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) return sample class Decoder(nn.Module): r""" The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. Args: in_channels (`int`, *optional*, defaults to 3): The number of input channels. out_channels (`int`, *optional*, defaults to 3): The number of output channels. up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. norm_type (`str`, *optional*, defaults to `"group"`): The normalization type to use. Can be either `"group"` or `"spatial"`. """ def __init__( self, in_channels: int = 3, out_channels: int = 3, up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int, ...] = (64,), layers_per_block: int = 2, norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block self.conv_in = nn.Conv2d( in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1, ) self.up_blocks = nn.ModuleList([]) temb_channels = in_channels if norm_type == "spatial" else None # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, add_attention=mid_block_add_attention, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 up_block = UpDecoderBlock2D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, out_channels=output_channel, # prev_output_channel=prev_output_channel, add_upsample=not is_final_block, resnet_eps=1e-6, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, # attention_head_dim=output_channel, temb_channels=temb_channels, resnet_time_scale_shift=norm_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_type == "spatial": self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) else: self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) self.gradient_checkpointing = False def forward( self, sample: torch.Tensor, latent_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: r"""The forward method of the `Decoder` class.""" sample = self.conv_in(sample) if torch.is_grad_enabled() and self.gradient_checkpointing: # middle sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds) # up for up_block in self.up_blocks: sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds) else: # middle sample = self.mid_block(sample, latent_embeds) # up for up_block in self.up_blocks: sample = up_block(sample, latent_embeds) # post-process if latent_embeds is None: sample = self.conv_norm_out(sample) else: sample = self.conv_norm_out(sample, latent_embeds) sample = self.conv_act(sample) sample = self.conv_out(sample) return sample class Flux2VAE(torch.nn.Module): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. force_upcast (`bool`, *optional*, default to `True`): If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix mid_block_add_attention (`bool`, *optional*, default to `True`): If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the mid_block will only have resnet blocks """ _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ( "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", ), up_block_types: Tuple[str, ...] = ( "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", ), block_out_channels: Tuple[int, ...] = ( 128, 256, 512, 512, ), layers_per_block: int = 2, act_fn: str = "silu", latent_channels: int = 32, norm_num_groups: int = 32, sample_size: int = 1024, # YiYi notes: not sure force_upcast: bool = True, use_quant_conv: bool = True, use_post_quant_conv: bool = True, mid_block_add_attention: bool = True, batch_norm_eps: float = 1e-4, batch_norm_momentum: float = 0.1, patch_size: Tuple[int, int] = (2, 2), ): super().__init__() # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, mid_block_add_attention=mid_block_add_attention, ) # pass init params to Decoder self.decoder = Decoder( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, mid_block_add_attention=mid_block_add_attention, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None self.bn = nn.BatchNorm2d( math.prod(patch_size) * latent_channels, eps=batch_norm_eps, momentum=batch_norm_momentum, affine=False, track_running_stats=True, ) self.use_slicing = False self.use_tiling = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self): r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size): return self._tiled_encode(x) enc = self.encoder(x) if self.quant_conv is not None: enc = self.quant_conv(enc) return enc def encode( self, x: torch.Tensor, return_dict: bool = True ): """ Encode a batch of images into latents. Args: x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ if self.use_slicing and x.shape[0] > 1: encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: h = self._encode(x) h = rearrange(h, "B C (H P) (W Q) -> B (C P Q) H W", P=2, Q=2) h = h[:, :128] latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(h.device, h.dtype) latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( h.device, h.dtype ) h = (h - latents_bn_mean) / latents_bn_std return h def _decode(self, z: torch.Tensor, return_dict: bool = True): if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): return self.tiled_decode(z, return_dict=return_dict) if self.post_quant_conv is not None: z = self.post_quant_conv(z) dec = self.decoder(z) if not return_dict: return (dec,) return dec def decode( self, z: torch.FloatTensor, return_dict: bool = True, generator=None ): latents_bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(z.device, z.dtype) latents_bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + 0.0001).to( z.device, z.dtype ) z = z * latents_bn_std + latents_bn_mean z = rearrange(z, "B (C P Q) H W -> B C (H P) (W Q)", P=2, Q=2) """ Decode a batch of images. Args: z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. Returns: [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ if self.use_slicing and z.shape[0] > 1: decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: decoded = self._decode(z) if not return_dict: return (decoded,) return decoded def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[2], b.shape[2], blend_extent) for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: blend_extent = min(a.shape[3], b.shape[3], blend_extent) for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the output, but they should be much less noticeable. Args: x (`torch.Tensor`): Input batch of images. Returns: `torch.Tensor`: The latent representation of the encoded videos. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent # Split the image into 512x512 tiles and encode them separately. rows = [] for i in range(0, x.shape[2], overlap_size): row = [] for j in range(0, x.shape[3], overlap_size): tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = self.encoder(tile) if self.config.use_quant_conv: tile = self.quant_conv(tile) row.append(tile) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) enc = torch.cat(result_rows, dim=2) return enc def tiled_encode(self, x: torch.Tensor, return_dict: bool = True): r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the output, but they should be much less noticeable. Args: x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent # Split the image into 512x512 tiles and encode them separately. rows = [] for i in range(0, x.shape[2], overlap_size): row = [] for j in range(0, x.shape[3], overlap_size): tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] tile = self.encoder(tile) if self.config.use_quant_conv: tile = self.quant_conv(tile) row.append(tile) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) moments = torch.cat(result_rows, dim=2) return moments def tiled_decode(self, z: torch.Tensor, return_dict: bool = True): r""" Decode a batch of images using a tiled decoder. Args: z (`torch.Tensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. Returns: [`~models.vae.DecoderOutput`] or `tuple`: If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is returned. """ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) row_limit = self.tile_sample_min_size - blend_extent # Split z into overlapping 64x64 tiles and decode them separately. # The tiles have an overlap to avoid seams between tiles. rows = [] for i in range(0, z.shape[2], overlap_size): row = [] for j in range(0, z.shape[3], overlap_size): tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] if self.config.use_post_quant_conv: tile = self.post_quant_conv(tile) decoded = self.decoder(tile) row.append(decoded) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) dec = torch.cat(result_rows, dim=2) if not return_dict: return (dec,) return dec def forward( self, sample: torch.Tensor, sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, ): r""" Args: sample (`torch.Tensor`): Input sample. sample_posterior (`bool`, *optional*, defaults to `False`): Whether to sample from the posterior. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. """ x = sample posterior = self.encode(x).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() dec = self.decode(z).sample if not return_dict: return (dec,) return dec