mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
wans2v usp
This commit is contained in:
@@ -459,10 +459,13 @@ class WanS2VModel(torch.nn.Module):
|
||||
)
|
||||
return x, rope_embs, mask_input
|
||||
|
||||
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len):
|
||||
def after_transformer_block(self, block_idx, hidden_states, audio_emb_global, audio_emb, original_seq_len, use_unified_sequence_parallel=False):
|
||||
if block_idx in self.audio_injector.injected_block_id.keys():
|
||||
audio_attn_id = self.audio_injector.injected_block_id[block_idx]
|
||||
num_frames = audio_emb.shape[1]
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sp_group
|
||||
hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
|
||||
|
||||
input_hidden_states = hidden_states[:, :original_seq_len].clone() # b (f h w) c
|
||||
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
|
||||
@@ -476,7 +479,9 @@ class WanS2VModel(torch.nn.Module):
|
||||
residual_out = self.audio_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
|
||||
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
hidden_states[:, :original_seq_len] = hidden_states[:, :original_seq_len] + residual_out
|
||||
|
||||
if use_unified_sequence_parallel:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size, get_sequence_parallel_rank
|
||||
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
return hidden_states
|
||||
|
||||
def cal_audio_emb(self, audio_input, motion_frames=[73, 19]):
|
||||
|
||||
@@ -1284,11 +1284,11 @@ def model_fn_wans2v(
|
||||
|
||||
# x and pose_cond
|
||||
pose_cond = torch.zeros_like(x) if pose_cond is None else pose_cond
|
||||
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond)) # torch.Size([1, 29120, 5120])
|
||||
seq_len_x = x.shape[1]
|
||||
x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(pose_cond))
|
||||
seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel
|
||||
|
||||
# reference image
|
||||
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) # torch.Size([1, 1456, 5120])
|
||||
ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents))
|
||||
grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw))
|
||||
x = torch.cat([x, ref_latents], dim=1)
|
||||
# mask
|
||||
@@ -1305,6 +1305,14 @@ def model_fn_wans2v(
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)).unsqueeze(2).transpose(0, 2)
|
||||
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
world_size, sp_rank = get_sequence_parallel_world_size(), get_sequence_parallel_rank()
|
||||
assert x.shape[1] % world_size == 0, f"the dimension after chunk must be divisible by world size, but got {x.shape[1]} and {get_sequence_parallel_world_size()}"
|
||||
x = torch.chunk(x, world_size, dim=1)[sp_rank]
|
||||
seg_idxs = [0] + list(torch.cumsum(torch.tensor([x.shape[1]] * world_size), dim=0).cpu().numpy())
|
||||
seq_len_x_list = [min(max(0, seq_len_x - seg_idxs[i]), x.shape[1]) for i in range(len(seg_idxs)-1)]
|
||||
seq_len_x = seq_len_x_list[sp_rank]
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
@@ -1315,7 +1323,7 @@ def model_fn_wans2v(
|
||||
with torch.autograd.graph.save_on_cpu():
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs,
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
@@ -1326,7 +1334,7 @@ def model_fn_wans2v(
|
||||
elif use_gradient_checkpointing:
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs,
|
||||
x, context, t_mod, seq_len_x, pre_compute_freqs[0],
|
||||
use_reentrant=False,
|
||||
)
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
@@ -1335,10 +1343,13 @@ def model_fn_wans2v(
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs)
|
||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x)
|
||||
x = block(x, context, t_mod, seq_len_x, pre_compute_freqs[0])
|
||||
x = dit.after_transformer_block(block_id, x, audio_emb_global, merged_audio_emb, seq_len_x_global, use_unified_sequence_parallel)
|
||||
|
||||
x = x[:, :seq_len_x]
|
||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
x = get_sp_group().all_gather(x, dim=1)
|
||||
|
||||
x = x[:, :seq_len_x_global]
|
||||
x = dit.head(x, t[:-1])
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
# make compatible with wan video
|
||||
|
||||
Reference in New Issue
Block a user