mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
import torch
|
|
|
|
|
|
class Attention(torch.nn.Module):
    """Multi-head self-/cross-attention built on ``scaled_dot_product_attention``.

    Queries are projected from ``hidden_states``. Keys and values are projected
    from ``encoder_hidden_states``, or from ``hidden_states`` itself when no
    encoder states are given (self-attention). The concatenated head outputs
    are projected back to ``q_dim`` by ``to_out``.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        """Create the query/key/value/output projections.

        Args:
            q_dim: Feature size of the query input and of the module output.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head; the inner projection width is
                ``num_heads * head_dim``.
            kv_dim: Feature size of the key/value input. Defaults to ``q_dim``.
            bias_q: Whether ``to_q`` has a bias term.
            bias_kv: Whether ``to_k`` and ``to_v`` have bias terms.
            bias_out: Whether ``to_out`` has a bias term.
        """
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def _split_heads(self, x):
        """Reshape ``(batch, seq, heads * head_dim)`` into ``(batch, heads, seq, head_dim)``."""
        # Use the tensor's own batch size so queries and keys/values with
        # different batch dims are never silently folded into the sequence axis.
        return x.reshape(x.shape[0], -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: Query input, expected as ``(batch, q_len, q_dim)``.
            encoder_hidden_states: Key/value input, expected as
                ``(batch, kv_len, kv_dim)``. When ``None``, ``hidden_states`` is
                used (self-attention).
            attn_mask: Optional mask forwarded unchanged to
                ``torch.nn.functional.scaled_dot_product_attention``. It must be
                broadcastable to ``(batch, num_heads, q_len, kv_len)``.

        Returns:
            Tensor of shape ``(batch, q_len, q_dim)``.
        """
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self._split_heads(self.to_q(hidden_states))
        k = self._split_heads(self.to_k(encoder_hidden_states))
        v = self._split_heads(self.to_v(encoder_hidden_states))

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # Merge heads: (batch, heads, q_len, head_dim) -> (batch, q_len, heads * head_dim).
        # reshape rather than view: depending on the SDPA backend that ran, the
        # transposed output may be non-contiguous, and view() would raise.
        hidden_states = hidden_states.transpose(1, 2).reshape(
            hidden_states.shape[0], -1, self.num_heads * self.head_dim
        )
        # Defensive cast back to the query dtype.
        hidden_states = hidden_states.to(q.dtype)

        return self.to_out(hidden_states)
|