mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 18:58:11 +00:00
fea : enable wan video usp for arbitrary seq len
This commit is contained in:
@@ -26,15 +26,12 @@ def pad_freqs(original_tensor, target_len):
|
|||||||
|
|
||||||
def rope_apply(x, freqs, num_heads):
|
def rope_apply(x, freqs, num_heads):
|
||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
s_per_rank = x.shape[1]
|
|
||||||
|
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
|
||||||
sp_size = get_sequence_parallel_world_size()
|
|
||||||
sp_rank = get_sequence_parallel_rank()
|
sp_rank = get_sequence_parallel_rank()
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
freqs_rank = torch.chunk(freqs, dim=0)[sp_rank] # chunk freqs like x
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
|
||||||
|
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
@@ -73,6 +70,9 @@ def usp_dit_forward(self,
|
|||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
# Context Parallel
|
# Context Parallel
|
||||||
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
|
seq_lens = [chunk.shape[1] for chunk in chunks]
|
||||||
|
|
||||||
x = torch.chunk(
|
x = torch.chunk(
|
||||||
x, get_sequence_parallel_world_size(),
|
x, get_sequence_parallel_world_size(),
|
||||||
dim=1)[get_sequence_parallel_rank()]
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
@@ -98,7 +98,15 @@ def usp_dit_forward(self,
|
|||||||
x = self.head(x, t)
|
x = self.head(x, t)
|
||||||
|
|
||||||
# Context Parallel
|
# Context Parallel
|
||||||
|
max_len = seq_lens[0]
|
||||||
|
b, s, c = x.shape
|
||||||
|
if s != max_len:
|
||||||
|
padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
|
||||||
|
x = torch.cat([x, padding_tensor], dim=1)
|
||||||
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
# remove pad
|
||||||
|
x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)])
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, (f, h, w))
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
|||||||
@@ -594,7 +594,12 @@ def model_fn_wan_video(
|
|||||||
# blocks
|
# blocks
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)
|
||||||
|
seq_lens = [chunk.shape[1] for chunk in chunks]
|
||||||
|
x = torch.chunk(
|
||||||
|
x, get_sequence_parallel_world_size(),
|
||||||
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
@@ -612,6 +617,14 @@ def model_fn_wan_video(
|
|||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if use_unified_sequence_parallel:
|
if use_unified_sequence_parallel:
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
max_len = seq_lens[0]
|
||||||
|
b, s, c = x.shape
|
||||||
|
if s != max_len:
|
||||||
|
padding_tensor = torch.ones(b, max_len - s, c, dtype=x.dtype, device=x.device)
|
||||||
|
x = torch.cat([x, padding_tensor], dim=1)
|
||||||
|
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
# remove pad
|
||||||
|
x = torch.cat([x[:,max_len*id:seq_lens[id],:] for id in range(seq_lens)])
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
Reference in New Issue
Block a user