co63oc
2025-02-26 14:18:36 +08:00
parent bed770248b
commit 4268f5466b
9 changed files with 14 additions and 14 deletions

@@ -88,7 +88,7 @@ class LLaMaEmbedding(nn.Module):
embeddings = embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
- # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+ # Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
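For readers skimming the hunk above: the corrected comment refers to reordering the embedding output from [b s h] to [s b h] once, up front, so later attention code needs no explicit transposes. A minimal standalone sketch of that layout change follows; the shapes are illustrative and not taken from the patched file.

import torch

# Hypothetical shapes for illustration only (not from the actual model config).
b, s, h = 2, 8, 16
embeddings = torch.randn(b, s, h)                     # [b s h]
# Swap the batch and sequence dimensions and make the result contiguous,
# mirroring the transpose(0, 1).contiguous() call in the diff context.
embeddings = embeddings.transpose(0, 1).contiguous()  # [s b h]
assert embeddings.shape == (s, b, h)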
@@ -326,7 +326,7 @@ class MultiQueryAttention(nn.Module):
dim=-1,
)
- # gather on 1st dimention
+ # gather on 1st dimension
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
xk, xv = xkv.chunk(2, -1)
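For context on the lines above (not part of the patch): the fused key/value projection xkv is viewed per query group and then chunked into separate key and value tensors, which is how multi-query attention shares keys and values across heads. A minimal sketch with made-up sizes, assuming the same view/chunk pattern as the diff context:

import torch

# Made-up sizes for illustration only; the real values come from the model config.
seqlen, bsz, n_local_groups, head_dim = 8, 2, 4, 32
xkv = torch.randn(seqlen, bsz, n_local_groups * 2 * head_dim)
# Reshape to one (key, value) pair per group, then split the last dimension.
xkv = xkv.view(seqlen, bsz, n_local_groups, 2 * head_dim)
xk, xv = xkv.chunk(2, -1)  # each: [seqlen, bsz, n_local_groups, head_dim]
assert xk.shape == xv.shape == (seqlen, bsz, n_local_groups, head_dim)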
@@ -357,7 +357,7 @@ class MultiQueryAttention(nn.Module):
output = self.core_attention(xq, xk, xv,
cu_seqlens=cu_seqlens,
max_seq_len=max_seq_len)
- # reduce-scatter only support first dimention now
+ # reduce-scatter only support first dimension now
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
else:
xq, xk, xv = [