minor fix

This commit is contained in:
mi804
2025-08-28 10:13:52 +08:00
parent caa17da5b9
commit 9cea10cc69
6 changed files with 7 additions and 76 deletions

View File

@@ -562,7 +562,7 @@ class WanS2VModel(torch.nn.Module):
context,
t_mod,
seq_len_x,
pre_compute_freqs,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -577,7 +577,7 @@ class WanS2VModel(torch.nn.Module):
context,
t_mod,
seq_len_x,
pre_compute_freqs,
pre_compute_freqs[0],
use_reentrant=False,
)
x = torch.utils.checkpoint.checkpoint(
@@ -586,7 +586,7 @@ class WanS2VModel(torch.nn.Module):
use_reentrant=False,
)
else:
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
x = self.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
x = x[:, :seq_len_x]