[feat] add VACE sequence parallel (#1345)

* add VACE sequence parallel

* resolve conflict

---------

Co-authored-by: yuan <yuan@yuandeMacBook-Pro.local>
Co-authored-by: Hong Zhang <41229682+mi804@users.noreply.github.com>
This commit is contained in:
Cao Yuan
2026-03-23 15:46:27 +08:00
committed by GitHub
parent 5bccd60c80
commit 5d198287f0
3 changed files with 50 additions and 12 deletions

View File

@@ -117,6 +117,39 @@ def usp_dit_forward(self,
return x
def usp_vace_forward(
    self, x, vace_context, context, t_mod, freqs,
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
):
    """Run the VACE blocks on this rank's slice of the VACE context.

    Each entry of ``vace_context`` is patch-embedded, flattened into a token
    sequence and zero-padded to the full (unsharded) sequence length, which is
    derived as ``x.shape[1] * world_size``. The padded sequences are
    concatenated along dim 0, split along the sequence dim into
    ``world_size`` chunks, and only the chunk belonging to the current
    sequence-parallel rank is fed through ``self.vace_blocks``.

    Args:
        x: Rank-local hidden states; dim 1 is treated as the sharded
            sequence dim.
        vace_context: Iterable of tensors accepted by
            ``self.vace_patch_embedding`` after ``unsqueeze(0)``.
        context, t_mod, freqs: Passed through unchanged to every VACE block.
        use_gradient_checkpointing: Forwarded to
            ``gradient_checkpoint_forward``.
        use_gradient_checkpointing_offload: Forwarded to
            ``gradient_checkpoint_forward``.

    Returns:
        Tuple of tensors obtained by unbinding the final block output along
        dim 0 and dropping the last entry. The sequence dim of each entry is
        the rank-local shard, not the full sequence.
    """
    world_size = get_sequence_parallel_world_size()
    # x is already sharded, so the unsharded length is recovered by scaling.
    padded_len = x.shape[1] * world_size

    padded_tokens = []
    for latent in vace_context:
        # Patch-embed, then collapse dims >= 2 into one token axis:
        # (1, C, ...) -> (1, tokens, C).
        tokens = self.vace_patch_embedding(latent.unsqueeze(0))
        tokens = tokens.flatten(2).transpose(1, 2)
        # Zero-pad on the sequence axis up to the full sequence length.
        filler = tokens.new_zeros(1, padded_len - tokens.size(1), tokens.size(2))
        padded_tokens.append(torch.cat([tokens, filler], dim=1))
    stacked = torch.cat(padded_tokens)

    # Shard along the sequence dim BEFORE the blocks so that each rank only
    # processes its own slice.
    shards = torch.chunk(stacked, world_size, dim=1)
    local = shards[get_sequence_parallel_rank()]

    # NOTE(review): per the original author's comment, the blocks' self_attn
    # is expected to be monkey-patched to usp_attn_forward by the caller —
    # not visible from this function.
    for vace_block in self.vace_blocks:
        local = gradient_checkpoint_forward(
            vace_block,
            use_gradient_checkpointing,
            use_gradient_checkpointing_offload,
            local, x, context, t_mod, freqs,
        )

    # Everything except the last dim-0 entry is returned as hints; they stay
    # sharded per rank.
    return torch.unbind(local)[:-1]
def usp_attn_forward(self, x, freqs):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))