mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Mova (#1337)
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
This commit is contained in:
@@ -1 +1 @@
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp
|
||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||
|
||||
@@ -143,4 +143,31 @@ def usp_attn_forward(self, x, freqs):
|
||||
|
||||
del q, k, v
|
||||
getattr(torch, parse_device_type(x.device)).empty_cache()
|
||||
return self.o(x)
|
||||
return self.o(x)
|
||||
|
||||
|
||||
def get_current_chunk(x, dim=1):
    """Return this sequence-parallel rank's chunk of ``x`` along ``dim``.

    ``x`` is split into ``world_size`` chunks with ``torch.chunk``; because
    the split can be uneven, every chunk is zero-padded at the *end* of
    ``dim`` up to the size of the first (largest) chunk, so all ranks hold
    tensors of identical shape. The inverse operation is
    ``gather_all_chunks`` (pass the original length as ``seq_len`` to strip
    the padding again).

    Args:
        x: tensor to shard across sequence-parallel ranks.
        dim: dimension to shard along (negative values are supported).

    Returns:
        The zero-padded chunk belonging to the current rank.
    """
    ndims = x.ndim
    # Normalize negative dims up front: the pad-tuple index computed below
    # is only valid for a non-negative dim.
    dim = dim % ndims
    chunks = torch.chunk(x, get_sequence_parallel_world_size(), dim=dim)
    # torch.chunk makes the first chunk the largest, so it sets the target size.
    max_size = chunks[0].size(dim)
    # F.pad's tuple is ordered from the LAST dimension backwards; this is the
    # position of the "pad after" entry for `dim`.
    pad_end_index = 2 * (ndims - 1 - dim) + 1
    padded = []
    for chunk in chunks:
        pad = [0] * (2 * ndims)
        pad[pad_end_index] = max_size - chunk.size(dim)
        padded.append(torch.nn.functional.pad(chunk, tuple(pad), value=0))
    return padded[get_sequence_parallel_rank()]
|
||||
|
||||
|
||||
def gather_all_chunks(x, seq_len=None, dim=1):
    """Reassemble a tensor sharded by ``get_current_chunk``.

    All-gathers the per-rank chunks along ``dim`` and, when ``seq_len`` is
    given, trims the result back to that length (dropping the zero padding
    that was added to equalize chunk sizes).

    Args:
        x: this rank's (padded) chunk.
        seq_len: original length along ``dim``; ``None`` keeps the padding.
        dim: dimension the tensor was sharded along.

    Returns:
        The gathered (and optionally trimmed) tensor.
    """
    gathered = get_sp_group().all_gather(x, dim=dim)
    if seq_len is None:
        return gathered
    # Keep only the first seq_len entries along the gather dimension.
    index = [slice(None)] * gathered.ndim
    index[dim] = slice(0, seq_len)
    return gathered[tuple(index)]
|
||||
|
||||
Reference in New Issue
Block a user