# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 23:08:13 +00:00
# 113 lines, 4.6 KiB, Python
import torch
|
|
|
|
|
|
class Attention(torch.nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Queries are projected from ``hidden_states``; keys and values are
    projected from ``encoder_hidden_states`` (or from ``hidden_states``
    itself when no separate context is given). The attended result is
    projected back to ``q_dim``.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        # Keys/values share the query width unless a context width is given.
        if kv_dim is None:
            kv_dim = q_dim
        inner_width = num_heads * head_dim

        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, inner_width, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, inner_width, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, inner_width, bias=bias_kv)
        self.to_out = torch.nn.Linear(inner_width, q_dim, bias=bias_out)

    def _split_heads(self, projected, batch):
        """Reshape ``(batch, seq, heads * head_dim)`` to ``(batch, heads, seq, head_dim)``."""
        per_head = projected.view(batch, -1, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attend from ``hidden_states`` to the context and project the result.

        Args:
            hidden_states: Query-side input whose last dimension is ``q_dim``.
            encoder_hidden_states: Optional key/value-side input whose last
                dimension is ``kv_dim``; defaults to ``hidden_states``.
            attn_mask: Optional mask forwarded unchanged to
                ``torch.nn.functional.scaled_dot_product_attention``.

        Returns:
            Tensor with ``q_dim`` as its last dimension.
        """
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        # The batch size is read from the key/value side.
        batch = context.shape[0]

        query = self._split_heads(self.to_q(hidden_states), batch)
        key = self._split_heads(self.to_k(context), batch)
        value = self._split_heads(self.to_v(context), batch)

        attended = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
        # Merge the heads back into a single feature dimension.
        merged = attended.transpose(1, 2).reshape(batch, -1, self.num_heads * self.head_dim)
        merged = merged.to(query.dtype)

        return self.to_out(merged)
|
|
|
|
|
|
class CLIPEncoderLayer(torch.nn.Module):
    """Pre-norm encoder layer: self-attention, then a two-layer MLP.

    Each sub-block normalises its input with a LayerNorm and adds its
    output back onto the un-normalised input (residual connection).
    """

    def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
        super().__init__()
        self.attn = Attention(
            q_dim=embed_dim,
            num_heads=num_heads,
            head_dim=head_dim,
            bias_q=True,
            bias_kv=True,
            bias_out=True,
        )
        self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
        self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
        self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
        self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)

        # Chooses between the sigmoid approximation and torch's GELU in forward().
        self.use_quick_gelu = use_quick_gelu

    def quickGELU(self, x):
        """Return ``x * sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate

    def forward(self, hidden_states, attn_mask=None):
        """Run the attention and MLP sub-blocks, each with a residual add.

        Args:
            hidden_states: Input tensor whose last dimension is ``embed_dim``.
            attn_mask: Optional mask passed through to the attention module.

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        # Attention sub-block.
        attn_out = self.attn(self.layer_norm1(hidden_states), attn_mask=attn_mask)
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block.
        expanded = self.fc1(self.layer_norm2(after_attn))
        activation = self.quickGELU if self.use_quick_gelu else torch.nn.functional.gelu
        mlp_out = self.fc2(activation(expanded))

        return after_attn + mlp_out
|
|
|
|
|
|
class FluxTextEncoderClip(torch.nn.Module):
    """CLIP-style text encoder: token + position embeddings, a stack of
    ``CLIPEncoderLayer`` blocks under a causal mask, and a final LayerNorm.

    ``forward`` returns both a pooled embedding and the hidden states of an
    intermediate layer selected by ``clip_skip``.
    """

    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (zero-initialised here; stored as a Parameter so it
        # is part of the state dict)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask: cached causal mask. Kept as a plain attribute (not a
        # buffer) so the state-dict layout is unchanged; forward() must never
        # write into it.
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        """Return a ``(length, length)`` additive causal mask.

        Entries strictly above the diagonal are ``-inf``; all others are 0.
        """
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=2, extra_mask=None):
        """Encode token ids.

        Args:
            input_ids: Integer tensor of token ids, ``(batch, seq_len)`` with
                ``seq_len <= max_position_embeddings``.
            clip_skip: Which layer's output to return as ``hidden_states``,
                counted from the end (1 = last layer). Must be in
                ``[1, num_encoder_layers]``.
            extra_mask: Optional tensor; positions where ``extra_mask[0] == 0``
                are additionally masked out as attention keys. Only the first
                row is consulted, and it is applied to the whole batch.

        Returns:
            ``(pooled_embeds, hidden_states)``. ``pooled_embeds`` is the
            final-LayerNorm output at the position of the largest token id in
            each sequence (presumably the end-of-text token -- confirm against
            the tokenizer). ``hidden_states`` is the un-normalised output of
            the layer chosen by ``clip_skip``.

        Raises:
            ValueError: If ``clip_skip`` does not select an existing layer.
        """
        num_layers = len(self.encoders)
        if not 1 <= clip_skip <= num_layers:
            # Previously this fell through to an UnboundLocalError at return.
            raise ValueError(f"clip_skip must be in [1, {num_layers}], got {clip_skip}")

        seq_len = input_ids.shape[-1]
        embeds = self.token_embedding(input_ids)
        # Slice to the actual sequence length so inputs shorter than
        # max_position_embeddings are supported.
        position_embeds = self.position_embeds[:, :seq_len]
        embeds = embeds + position_embeds.to(dtype=embeds.dtype, device=input_ids.device)

        attn_mask = self.attn_mask[:seq_len, :seq_len].to(device=embeds.device, dtype=embeds.dtype)
        if extra_mask is not None:
            # `.to()` / slicing may alias the cached mask; clone so the
            # in-place write below cannot leak into later calls.
            attn_mask = attn_mask.clone()
            attn_mask[:, extra_mask[0] == 0] = float("-inf")

        # Layer index whose output is returned as hidden_states.
        target_layer = num_layers - clip_skip
        hidden_states = None
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id == target_layer:
                hidden_states = embeds

        embeds = self.final_layer_norm(embeds)
        batch_index = torch.arange(embeds.shape[0], device=embeds.device)
        pooled_position = input_ids.to(dtype=torch.int).argmax(dim=-1)
        pooled_embeds = embeds[batch_index, pooled_position]
        return pooled_embeds, hidden_states
|