import inspect
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward


# Activation names accepted by `TimestepEmbedding` (`act_fn` / `post_act_fn`).
_ACTIVATIONS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def _get_activation(name: Optional[str]) -> Optional[nn.Module]:
    """Return a fresh activation module for `name`, or `None` when `name` is `None`.

    Raises:
        ValueError: if `name` is not one of the keys of `_ACTIVATIONS`.
    """
    if name is None:
        return None
    try:
        return _ACTIVATIONS[name.lower()]()
    except KeyError:
        raise ValueError(
            f"Unsupported activation function: {name!r}. Supported: {sorted(_ACTIVATIONS)}."
        ) from None


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings

    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TimestepEmbedding(nn.Module):
    """Two-layer MLP (`linear_1` -> activation -> `linear_2` -> optional post-activation).

    Args:
        in_channels: Input feature size of `linear_1`.
        time_embed_dim: Output size of `linear_1` / input size of `linear_2`.
        act_fn: Name of the activation applied between the two linear layers
            (one of "silu", "swish", "mish", "gelu", "relu"), or `None` for no activation.
        out_dim: Output size of `linear_2`; defaults to `time_embed_dim`.
        post_act_fn: Optional name of an activation applied after `linear_2`.
        cond_proj_dim: If given, a bias-free linear layer projects `condition` to
            `in_channels` and the result is added to the input sample.
        sample_proj_bias: Whether `linear_1` and `linear_2` use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        # Honor the requested activation (the default "silu" yields nn.SiLU).
        self.act = _get_activation(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        # Always define `post_act` so that `forward` never hits a missing attribute.
        self.post_act = _get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    """Module wrapper around `get_timestep_embedding` with fixed configuration."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class AdaLayerNormContinuous(nn.Module):
    r"""
    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization layer to use. Values supported: "layer_norm", "rms_norm".

    Raises:
        ValueError: if `norm_type` is not one of the supported values.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            # Fail at construction time instead of with an AttributeError in `forward`.
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, torch.Tensor, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
    positions 'pos'. The 'theta' parameter scales the frequencies.

    Args:
        dim (`int`): Dimension of the frequency tensor. Must be even.
        pos (`np.ndarray`, `torch.Tensor` or `int`):
            Position indices for the frequency tensor. [S] or scalar (an `int` n is expanded to `arange(n)`).
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            the dtype of the frequency tensor.

    Returns:
        If `use_real` is True, a `(cos, sin)` tuple of float32 tensors of shape [S, D]. Otherwise a single tensor
        of complex exponentials of shape [S, D/2].
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore  # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore   # [S, D/2]
    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> torch.Tensor:
    """
    Apply rotary embeddings to a query or key tensor using the given frequency tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to. [B, H, S, D] when `sequence_dim == 2`,
            [B, S, H, D] when `sequence_dim == 1`.
        freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`):
            Precomputed frequencies. A `(cos, sin)` tuple ([S, D], [S, D]) when `use_real` is True, otherwise a
            complex tensor.
        use_real (`bool`): Whether `freqs_cis` is a real `(cos, sin)` pair.
        use_real_unbind_dim (`int`):
            -1 if the real/imaginary components of `x` are interleaved along the last dim, -2 if they are stored as
            two contiguous halves.
        sequence_dim (`int`): Index of the sequence dimension of `x` (1 or 2); only used when `use_real` is True.

    Returns:
        `torch.Tensor`: `x` with rotary embeddings applied, same shape and dtype as `x`.

    Raises:
        ValueError: if `sequence_dim` or `use_real_unbind_dim` has an unsupported value.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Rotate in float32 and cast back to the input dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
        return x_out.type_as(x)


def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """Compute Q/K/V with the separate projection layers of `attn`.

    The encoder projections are `None` unless `encoder_hidden_states` is given and
    `attn.added_kv_proj_dim` is set.
    """
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """Compute Q/K/V with the fused `to_qkv` (and optional `to_added_qkv`) layers of `attn`.

    The encoder projections are `None` unless `encoder_hidden_states` is given and `attn` has a
    `to_added_qkv` layer.
    """
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    # Plain `None` (not a 1-tuple), consistent with `_get_projections`, so `is None` checks work.
    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """Dispatch to the (unfused) projection helper."""
    return _get_projections(attn, hidden_states, encoder_hidden_states)


class Flux2SwiGLU(nn.Module):
    """
    Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
    layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split the last dim in half: the first half is the gate, the second half the value.
        x1, x2 = x.chunk(2, dim=-1)
        x = self.gate_fn(x1) * x2
        return x


class Flux2FeedForward(nn.Module):
    """Feed-forward block: `linear_in` -> `Flux2SwiGLU` -> `linear_out`."""

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: float = 3.0,
        inner_dim: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out or dim

        # Flux2SwiGLU will reduce the dimension by half
        self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.act_fn = Flux2SwiGLU()
        self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_in(x)
        x = self.act_fn(x)
        x = self.linear_out(x)
        return x


class Flux2AttnProcessor:
    """Attention processor for `Flux2Attention` (joint text + image attention)."""

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run attention over `hidden_states`, optionally jointly with `encoder_hidden_states`.

        Returns `(hidden_states, encoder_hidden_states)` when the encoder stream took part in attention,
        otherwise only `hidden_states`. `attention_mask` is accepted but not used here.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # The encoder projections only exist when an encoder input was given AND the layer has the
        # added projections; use the same condition for concatenation and for the final split.
        has_context = encoder_query is not None
        if has_context:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Text tokens come first along the sequence dimension.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d",
            k_pattern="b s n d",
            v_pattern="b s n d",
            out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if has_context:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if has_context:
            return hidden_states, encoder_hidden_states
        return hidden_states


class Flux2Attention(torch.nn.Module):
    """Attention layer with QK RMSNorm and optional added (encoder) Q/K/V projections."""

    _default_processor_cls = Flux2AttnProcessor
    _available_processors = [Flux2AttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        # When `out_dim` is given, the head count is derived from it and `heads` is ignored.
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Silently drop extra kwargs that the processor's `__call__` does not accept.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


class Flux2ParallelSelfAttnProcessor:
    """Attention processor for `Flux2ParallelSelfAttention` (fused attention + MLP)."""

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Parallel in (QKV + MLP in) projection
        hidden_states = attn.to_qkv_mlp_proj(hidden_states)
        qkv, mlp_hidden_states = torch.split(
            hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
        )

        # Handle the attention logic
        query, key, value = qkv.chunk(3, dim=-1)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype)
        hidden_states = attention_forward(
            query,
            key,
            value,
            q_pattern="b s n d",
            k_pattern="b s n d",
            v_pattern="b s n d",
            out_pattern="b s n d",
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # Handle the feedforward (FF) logic
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

        # Concatenate and parallel output projection
        hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
        hidden_states = attn.to_out(hidden_states)

        return hidden_states


class Flux2ParallelSelfAttention(torch.nn.Module):
    """
    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.

    This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
    input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
    """

    _default_processor_cls = Flux2ParallelSelfAttnProcessor
    _available_processors = [Flux2ParallelSelfAttnProcessor]
    # Does not support QKV fusion as the QKV projections are always fused
    _supports_qkv_fusion = False

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        mlp_ratio: float = 4.0,
        mlp_mult_factor: int = 2,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        # When `out_dim` is given, the head count is derived from it and `heads` is ignored.
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
        self.mlp_mult_factor = mlp_mult_factor

        # Fused QKV projections + MLP input projection
        self.to_qkv_mlp_proj = torch.nn.Linear(
            self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
        )
        self.mlp_act_fn = Flux2SwiGLU()

        # QK Norm
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        # Fused attention output projection + MLP output projection
        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Silently drop extra kwargs that the processor's `__call__` does not accept.
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)


class Flux2SingleTransformerBlock(nn.Module):
    """Single-stream block: modulated LayerNorm followed by fused parallel attention + MLP."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
        # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
        # for a visual depiction of this type of transformer block.
        self.attn = Flux2ParallelSelfAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            out_bias=bias,
            eps=eps,
            mlp_ratio=mlp_ratio,
            mlp_mult_factor=2,
            processor=Flux2ParallelSelfAttnProcessor(),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        split_hidden_states: bool = False,
        text_seq_len: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the block.

        Returns the full (text + image) sequence, or `(encoder_hidden_states, hidden_states)` split at
        `text_seq_len` when `split_hidden_states` is True.
        """
        # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
        # concatenated
        if encoder_hidden_states is not None:
            text_seq_len = encoder_hidden_states.shape[1]
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        mod_shift, mod_scale, mod_gate = temb_mod_params

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = hidden_states + mod_gate * attn_output
        if hidden_states.dtype == torch.float16:
            # Clamp to the finite float16 range.
            hidden_states = hidden_states.clip(-65504, 65504)

        if split_hidden_states:
            encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
            return encoder_hidden_states, hidden_states
        else:
            return hidden_states


class Flux2TransformerBlock(nn.Module):
    """Double-stream block: separate image/text norms and feed-forwards around a joint attention."""

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.attn = Flux2Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            added_proj_bias=bias,
            out_bias=bias,
            eps=eps,
            processor=Flux2AttnProcessor(),
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return `(encoder_hidden_states, hidden_states)`."""
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Modulation parameters shape: [1, 1, self.dim]
        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt

        # Img stream
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa

        # Conditioning txt stream
        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
        norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa

        # Attention on concatenated img + txt stream
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        attn_output, context_attn_output = attention_outputs

        # Process attention outputs for the image stream (`hidden_states`).
        attn_output = gate_msa * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_mlp * ff_output

        # Process attention outputs for the text stream (`encoder_hidden_states`).
        context_attn_output = c_gate_msa * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            # Clamp to the finite float16 range.
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


class Flux2PosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return `(freqs_cos, freqs_sin)`, the per-axis rotary tables concatenated along the last dim."""
        # Expected ids shape: [S, len(self.axes_dim)]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        # float32 is used on mps/npu devices, float64 everywhere else.
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
        for i in range(len(self.axes_dim)):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[..., i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


class Flux2TimestepGuidanceEmbeddings(nn.Module):
    """Sinusoidal projection + MLP embedding of the timestep, optionally summed with a guidance embedding."""

    def __init__(
        self,
        in_channels: int = 256,
        embedding_dim: int = 6144,
        bias: bool = False,
        guidance_embeds: bool = True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
        )

        if guidance_embeds:
            self.guidance_embedder = TimestepEmbedding(
                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
            )
        else:
            self.guidance_embedder = None

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))  # (N, D)

        if guidance is not None and self.guidance_embedder is not None:
            guidance_proj = self.time_proj(guidance)
            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))  # (N, D)
            time_guidance_emb = timesteps_emb + guidance_emb
            return time_guidance_emb
        else:
            return timesteps_emb


class Flux2Modulation(nn.Module):
    """Project a time embedding into `mod_param_sets` sets of (shift, scale, gate) parameters."""

    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets

        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        mod = self.act_fn(temb)
        mod = self.linear(mod)

        if mod.ndim == 2:
            # Insert a singleton dim at index 1 so the parameters broadcast over it.
            mod = mod.unsqueeze(1)
        mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
        # Return tuple of 3-tuples of modulation params shift/scale/gate
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))


class Flux2DiT(torch.nn.Module):
    """Flux 2 diffusion transformer: double-stream blocks followed by single-stream blocks."""

    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: Optional[int] = None,
        num_layers: int = 8,
        num_single_layers: int = 48,
        attention_head_dim: int = 128,
        num_attention_heads: int = 48,
        joint_attention_dim: int = 15360,
        timestep_guidance_channels: int = 256,
        mlp_ratio: float = 3.0,
        axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
        eps: float = 1e-6,
        guidance_embeds: bool = True,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # 1. Sinusoidal positional embedding for RoPE on image and text tokens
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # 2. Combined timestep + guidance embedding
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            in_channels=timestep_guidance_channels,
            embedding_dim=self.inner_dim,
            bias=False,
            guidance_embeds=guidance_embeds,
        )

        # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
        # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # 4. Input projections
        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # 5. Double Stream Transformer Blocks
        self.transformer_blocks = nn.ModuleList(
            [
                Flux2TransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_layers)
            ]
        )

        # 6. Single Stream Transformer Blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                Flux2SingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_single_layers)
            ]
        )

        # 7. Output layers
        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        use_gradient_checkpointing=False,
        use_gradient_checkpointing_offload=False,
    ):
        """Run the transformer and return the projected image tokens (text tokens are removed)."""
        # 0. Handle input arguments
        if joint_attention_kwargs is not None:
            # Work on a copy so the caller's dict is untouched; the "scale" entry is removed
            # and not forwarded to the blocks.
            joint_attention_kwargs = joint_attention_kwargs.copy()
            joint_attention_kwargs.pop("scale", None)

        num_txt_tokens = encoder_hidden_states.shape[1]

        # 1. Calculate timestep embedding and modulation parameters
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = self.time_guidance_embed(timestep, guidance)

        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)[0]

        # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # 3. Calculate RoPE embeddings from image and text tokens
        # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
        # text prompts of different lengths. Is this a use case we want to support?
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        # Text positions come first, matching the [text, image] token order used in attention.
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 4. Double Stream Transformer Blocks
        for block in self.transformer_blocks:
            encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb_mod_params_img=double_stream_mod_img,
                temb_mod_params_txt=double_stream_mod_txt,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Concatenate text and image streams for single-block inference
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 5. Single Stream Transformer Blocks
        for block in self.single_transformer_blocks:
            hidden_states = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing=use_gradient_checkpointing,
                use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
                hidden_states=hidden_states,
                encoder_hidden_states=None,
                temb_mod_params=single_stream_mod,
                image_rotary_emb=concat_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # Remove text tokens from concatenated stream
        hidden_states = hidden_states[:, num_txt_tokens:, ...]

        # 6. Output layers
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        return output