Merge pull request #1256 from Feng0w0/npu_fused

[model][NPU]:Add NPU fusion operator patch to Zimage model to improve performance
This commit is contained in:
Zhongjie Duan
2026-02-09 20:08:44 +08:00
committed by GitHub
5 changed files with 67 additions and 14 deletions

View File

@@ -88,6 +88,14 @@ class Attention(torch.nn.Module):
self.norm_q = RMSNorm(head_dim, eps=1e-5)
self.norm_k = RMSNorm(head_dim, eps=1e-5)
# Apply RoPE
def apply_rotary_emb(self, x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
def forward(self, hidden_states, freqs_cis, attention_mask):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
@@ -103,17 +111,9 @@ class Attention(torch.nn.Module):
if self.norm_k is not None:
key = self.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast(get_device_type(), enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
query = self.apply_rotary_emb(query, freqs_cis)
key = self.apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype