This commit is contained in:
co63oc
2025-02-26 14:18:36 +08:00
parent bed770248b
commit 4268f5466b
9 changed files with 14 additions and 14 deletions

View File

@@ -980,7 +980,7 @@ class Embedding(torch.nn.Module):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection: