Fix TeaCache bug with usp support integration and optimize memory usage by clearing attn cache

This commit is contained in:
calmhawk
2025-03-30 01:13:34 +08:00
parent c7035ad911
commit 52896fa8dd
2 changed files with 6 additions and 4 deletions

View File

@@ -124,4 +124,6 @@ def usp_attn_forward(self, x, freqs):
)
x = x.flatten(2)
del q, k, v
torch.cuda.empty_cache()
return self.o(x)

View File

@@ -396,13 +396,13 @@ def model_fn_wan_video(
else:
tea_cache_update = False
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
if tea_cache_update:
x = tea_cache.update(x)
else:
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
for block in dit.blocks:
x = block(x, context, t_mod, freqs)
if tea_cache is not None: