Fix num_frames in i2v (#339)

* Fix num_frames in i2v

* Remove print in flash_attention
This commit is contained in:
Kohaku-Blueleaf
2025-02-26 10:05:51 +08:00
committed by GitHub
parent af7d305f00
commit 020560d2b5
2 changed files with 4 additions and 6 deletions

View File

@@ -112,7 +112,6 @@ def flash_attention(
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
elif FLASH_ATTN_2_AVAILABLE:
print(q_lens, lq, k_lens, lk, causal, window_size)
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
@@ -128,7 +127,6 @@ def flash_attention(
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
print(x.shape)
else:
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
k = k.unsqueeze(0).transpose(1, 2).to(dtype)