feat: sp for wan

This commit is contained in:
Jinzhe Pan
2025-03-17 08:31:45 +00:00
parent 39890f023f
commit 42cb7d96bb
5 changed files with 175 additions and 11 deletions

View File

@@ -90,6 +90,8 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}")
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)