mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
[feat] add VACE sequence parallel (#1345)
* add VACE sequence parallel * resolve conflict --------- Co-authored-by: yuan <yuan@yuandeMacBook-Pro.local> Co-authored-by: Hong Zhang <41229682+mi804@users.noreply.github.com>
This commit is contained in:
@@ -117,6 +117,39 @@ def usp_dit_forward(self,
|
||||
return x
|
||||
|
||||
|
||||
def usp_vace_forward(
|
||||
self, x, vace_context, context, t_mod, freqs,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
):
|
||||
# Compute full sequence length from the sharded x
|
||||
full_seq_len = x.shape[1] * get_sequence_parallel_world_size()
|
||||
|
||||
# Embed vace_context via patch embedding
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, full_seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# Chunk VACE context along sequence dim BEFORE processing through blocks
|
||||
c = torch.chunk(c, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
# Process through vace_blocks (self_attn already monkey-patched to usp_attn_forward)
|
||||
for block in self.vace_blocks:
|
||||
c = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
c, x, context, t_mod, freqs
|
||||
)
|
||||
|
||||
# Hints are already sharded per-rank
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
|
||||
def usp_attn_forward(self, x, freqs):
|
||||
q = self.norm_q(self.q(x))
|
||||
k = self.norm_k(self.k(x))
|
||||
|
||||
Reference in New Issue
Block a user