mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
[feat] add VACE sequence parallel (#1345)
* add VACE sequence parallel * resolve conflict --------- Co-authored-by: yuan <yuan@yuandeMacBook-Pro.local> Co-authored-by: Hong Zhang <41229682+mi804@users.noreply.github.com>
This commit is contained in:
@@ -86,7 +86,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
def enable_usp(self):
|
def enable_usp(self):
|
||||||
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward
|
from ..utils.xfuser import get_sequence_parallel_world_size, usp_attn_forward, usp_dit_forward, usp_vace_forward
|
||||||
|
|
||||||
for block in self.dit.blocks:
|
for block in self.dit.blocks:
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
@@ -95,6 +95,14 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
for block in self.dit2.blocks:
|
for block in self.dit2.blocks:
|
||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
|
||||||
|
if self.vace is not None:
|
||||||
|
for block in self.vace.vace_blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
self.vace.forward = types.MethodType(usp_vace_forward, self.vace)
|
||||||
|
if self.vace2 is not None:
|
||||||
|
for block in self.vace2.vace_blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
self.vace2.forward = types.MethodType(usp_vace_forward, self.vace2)
|
||||||
self.sp_size = get_sequence_parallel_world_size()
|
self.sp_size = get_sequence_parallel_world_size()
|
||||||
self.use_unified_sequence_parallel = True
|
self.use_unified_sequence_parallel = True
|
||||||
|
|
||||||
@@ -1450,13 +1458,6 @@ def model_fn_wan_video(
|
|||||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||||
else:
|
else:
|
||||||
tea_cache_update = False
|
tea_cache_update = False
|
||||||
|
|
||||||
if vace_context is not None:
|
|
||||||
vace_hints = vace(
|
|
||||||
x, vace_context, context, t_mod, freqs,
|
|
||||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
|
||||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
|
||||||
)
|
|
||||||
|
|
||||||
# WanToDance
|
# WanToDance
|
||||||
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
|
||||||
@@ -1519,6 +1520,13 @@ def model_fn_wan_video(
|
|||||||
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
pad_shape = chunks[0].shape[1] - chunks[-1].shape[1]
|
||||||
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, chunks[0].shape[1]-chunk.shape[1]), value=0) for chunk in chunks]
|
||||||
x = chunks[get_sequence_parallel_rank()]
|
x = chunks[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
if vace_context is not None:
|
||||||
|
vace_hints = vace(
|
||||||
|
x, vace_context, context, t_mod, freqs,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
|
||||||
|
)
|
||||||
if tea_cache_update:
|
if tea_cache_update:
|
||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
@@ -1561,9 +1569,6 @@ def model_fn_wan_video(
|
|||||||
# VACE
|
# VACE
|
||||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||||
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
|
||||||
if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
|
|
||||||
current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
|
||||||
current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
|
|
||||||
x = x + current_vace_hint * vace_scale
|
x = x + current_vace_hint * vace_scale
|
||||||
|
|
||||||
# Animate
|
# Animate
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
from .xdit_context_parallel import usp_attn_forward, usp_dit_forward, usp_vace_forward, get_sequence_parallel_world_size, initialize_usp, get_current_chunk, gather_all_chunks
|
||||||
|
|||||||
@@ -117,6 +117,39 @@ def usp_dit_forward(self,
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def usp_vace_forward(
|
||||||
|
self, x, vace_context, context, t_mod, freqs,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
):
|
||||||
|
# Compute full sequence length from the sharded x
|
||||||
|
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
|
||||||
|
|
||||||
|
# Embed vace_context via patch embedding
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
# Chunk VACE context along sequence dim BEFORE processing through blocks
|
||||||
|
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
|
||||||
|
for block in self.vace_blocks:
|
||||||
|
c = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload,
|
||||||
|
c, x, context, t_mod, freqs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hints are already sharded per-rank
|
||||||
|
hints = torch.unbind(c)[:-1]
|
||||||
|
return hints
|
||||||
|
|
||||||
|
|
||||||
def usp_attn_forward(self, x, freqs):
|
def usp_attn_forward(self, x, freqs):
|
||||||
q = self.norm_q(self.q(x))
|
q = self.norm_q(self.q(x))
|
||||||
k = self.norm_k(self.k(x))
|
k = self.norm_k(self.k(x))
|
||||||
|
|||||||
Reference in New Issue
Block a user