This commit is contained in:
Artiprocher
2025-11-27 22:43:43 +08:00
parent 0b527c460f
commit 0b72c2b3ba
10 changed files with 1329 additions and 17 deletions

View File

@@ -150,25 +150,75 @@ class ConvAttention(torch.nn.Module):
return hidden_states
class Attention(torch.nn.Module):
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
super().__init__()
dim_inner = head_dim * num_heads
kv_dim = kv_dim if kv_dim is not None else q_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
batch_size = encoder_hidden_states.shape[0]
q = self.to_q(hidden_states)
k = self.to_k(encoder_hidden_states)
v = self.to_v(encoder_hidden_states)
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states = self.to_out(hidden_states)
return hidden_states
class VAEAttentionBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.transformer_blocks = torch.nn.ModuleList([
ConvAttention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
if use_conv_attention:
self.transformer_blocks = torch.nn.ModuleList([
ConvAttention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
else:
self.transformer_blocks = torch.nn.ModuleList([
Attention(
inner_dim,
num_attention_heads,
attention_head_dim,
bias_q=True,
bias_kv=True,
bias_out=True
)
for d in range(num_layers)
])
def forward(self, hidden_states, time_emb, text_emb, res_stack):
batch, _, height, width = hidden_states.shape
@@ -244,7 +294,7 @@ class DownSampler(torch.nn.Module):
class FluxVAEDecoder(torch.nn.Module):
def __init__(self):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@@ -253,7 +303,7 @@ class FluxVAEDecoder(torch.nn.Module):
self.blocks = torch.nn.ModuleList([
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
# UpDecoderBlock2D
ResnetBlock(512, 512, eps=1e-6),
@@ -316,7 +366,7 @@ class FluxVAEDecoder(torch.nn.Module):
class FluxVAEEncoder(torch.nn.Module):
def __init__(self):
def __init__(self, use_conv_attention=True):
super().__init__()
self.scaling_factor = 0.3611
self.shift_factor = 0.1159
@@ -340,7 +390,7 @@ class FluxVAEEncoder(torch.nn.Module):
ResnetBlock(512, 512, eps=1e-6),
# UNetMidBlock2D
ResnetBlock(512, 512, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
ResnetBlock(512, 512, eps=1e-6),
])