mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
Fix TeaCache bug with usp support integration and optimize memory usage by clearing attn cache
This commit is contained in:
@@ -124,4 +124,6 @@ def usp_attn_forward(self, x, freqs):
|
|||||||
)
|
)
|
||||||
x = x.flatten(2)
|
x = x.flatten(2)
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return self.o(x)
|
return self.o(x)
|
||||||
@@ -396,13 +396,13 @@ def model_fn_wan_video(
|
|||||||
else:
|
else:
|
||||||
tea_cache_update = False
|
tea_cache_update = False
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
# blocks
|
|
||||||
if use_unified_sequence_parallel:
|
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
for block in dit.blocks:
|
for block in dit.blocks:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user