This commit is contained in:
mi804
2026-04-23 16:52:59 +08:00
parent 1186379139
commit 394db06d86
7 changed files with 212 additions and 20 deletions

View File

@@ -349,7 +349,7 @@ class AttentionPooler(nn.Module):
) -> torch.Tensor:
B, T, P, D = x.shape
x = self.embed_tokens(x)
special_tokens = self.special_token.expand(B, T, 1, -1)
special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")