mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
codes
This commit is contained in:
@@ -349,7 +349,7 @@ class AttentionPooler(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
B, T, P, D = x.shape
|
||||
x = self.embed_tokens(x)
|
||||
special_tokens = self.special_token.expand(B, T, 1, -1)
|
||||
special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
|
||||
x = torch.cat([special_tokens, x], dim=2)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user