feat: sp for wan

This commit is contained in:
Jinzhe Pan
2025-03-17 08:31:45 +00:00
parent 39890f023f
commit 42cb7d96bb
5 changed files with 175 additions and 11 deletions

View File

@@ -1,3 +1,4 @@
import types
from ..models import ModelManager
from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder
@@ -12,6 +13,10 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
import torch.distributed as dist
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
@@ -135,11 +140,19 @@ class WanVideoPipeline(BasePipeline):
@staticmethod
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
if device is None: device = model_manager.device
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe.fetch_models(model_manager)
if use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
for block in pipe.dit.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
pipe.sp_size = get_sequence_parallel_world_size()
return pipe
@@ -375,11 +388,15 @@ def model_fn_wan_video(
x = tea_cache.update(x)
else:
# blocks
if dist.is_initialized() and dist.get_world_size() > 1:
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
for block in dit.blocks:
x = block(x, context, t_mod, freqs)
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)
x = dit.unpatchify(x, (f, h, w))
return x