mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
support wan tensor parallel (preview)
This commit is contained in:
@@ -108,6 +108,16 @@ class RMSNorm(nn.Module):
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
def __init__(self, num_heads):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, q, k, v):
|
||||
x = flash_attention(q=q, k=k, v=v, num_heads=self.num_heads)
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@@ -121,17 +131,16 @@ class SelfAttention(nn.Module):
|
||||
self.o = nn.Linear(dim, dim)
|
||||
self.norm_q = RMSNorm(dim, eps=eps)
|
||||
self.norm_k = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads)
|
||||
|
||||
def forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
v = self.v(x)
|
||||
x = flash_attention(
|
||||
q=rope_apply(q, freqs, self.num_heads),
|
||||
k=rope_apply(k, freqs, self.num_heads),
|
||||
v=v,
|
||||
num_heads=self.num_heads
|
||||
)
|
||||
q = rope_apply(q, freqs, self.num_heads)
|
||||
k = rope_apply(k, freqs, self.num_heads)
|
||||
x = self.attn(q, k, v)
|
||||
return self.o(x)
|
||||
|
||||
|
||||
@@ -153,6 +162,8 @@ class CrossAttention(nn.Module):
|
||||
self.k_img = nn.Linear(dim, dim)
|
||||
self.v_img = nn.Linear(dim, dim)
|
||||
self.norm_k_img = RMSNorm(dim, eps=eps)
|
||||
|
||||
self.attn = AttentionModule(self.num_heads)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
if self.has_image_input:
|
||||
@@ -163,7 +174,7 @@ class CrossAttention(nn.Module):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(ctx))
|
||||
v = self.v(ctx)
|
||||
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
||||
x = self.attn(q, k, v)
|
||||
if self.has_image_input:
|
||||
k_img = self.norm_k_img(self.k_img(img))
|
||||
v_img = self.v_img(img)
|
||||
|
||||
Reference in New Issue
Block a user