feat: sp for wan

Jinzhe Pan
2025-03-17 08:31:45 +00:00
parent 39890f023f
commit 42cb7d96bb
5 changed files with 175 additions and 11 deletions

View File

@@ -0,0 +1,127 @@
import torch
from typing import Optional
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)
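
# Added shape note: for `position` of shape (B,), the embedding has shape
# (B, dim), cosine terms in the first half and sine terms in the second, e.g.
# sinusoidal_embedding_1d(256, torch.arange(4)).shape == torch.Size([4, 256]).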

def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor
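
# Added note: `freqs` is complex in this pipeline, so `torch.ones` pads with
# 1 + 0j, the multiplicative identity -- positions beyond the real sequence
# pass through rope_apply unrotated instead of being zeroed out.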

def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)
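
# Added worked example (hypothetical numbers): with s_per_rank = 100 and
# sp_size = 4, `freqs` is padded to 400 rows and rank r takes
# freqs[100 * r : 100 * (r + 1)], so each rank rotates its local chunk with
# the rotary phases of its global positions in the full sequence.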

def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embedding = self.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context parallel: each rank keeps only its slice of the token sequence.
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            if use_gradient_checkpointing_offload:
                with torch.autograd.graph.save_on_cpu():
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x, context, t_mod, freqs,
                        use_reentrant=False,
                    )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context parallel: reassemble the full sequence before unpatchify.
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x
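
# Added caveat: torch.chunk only yields equal slices when f * h * w divides
# evenly by the sequence-parallel world size; pad_freqs covers the frequency
# side, but the token sequence itself is assumed to split evenly here, and
# the all_gather above relies on that.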

def usp_attn_forward(self, x, freqs):
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    q = rope_apply(q, freqs, self.num_heads)
    k = rope_apply(k, freqs, self.num_heads)
    q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    x = xFuserLongContextAttention()(
        None,
        query=q,
        key=k,
        value=v,
    )
    x = x.flatten(2)
    return self.o(x)
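
# Added shape note (assumption about xFuser internals): q, k, v enter the
# long-context attention as (b, s_local, num_heads, head_dim);
# xFuserLongContextAttention is expected to run a Ulysses-style all-to-all
# (full sequence on each rank, heads sharded across ranks) and return the
# same layout, which .flatten(2) folds back to (b, s_local, n * d).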

View File

@@ -90,6 +90,8 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
+    print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}")
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)

View File

@@ -1,3 +1,4 @@
+import types
 from ..models import ModelManager
 from ..models.wan_video_dit import WanModel
 from ..models.wan_video_text_encoder import WanTextEncoder
@@ -12,6 +13,10 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 from typing import Optional
+import torch.distributed as dist
+from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                     get_sequence_parallel_world_size,
+                                     get_sp_group)
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
@@ -135,11 +140,19 @@ class WanVideoPipeline(BasePipeline):
     @staticmethod
-    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
+    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
         if device is None: device = model_manager.device
         if torch_dtype is None: torch_dtype = model_manager.torch_dtype
         pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
         pipe.fetch_models(model_manager)
+        if use_usp:
+            from xfuser.core.distributed import get_sequence_parallel_world_size
+            from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
+            for block in pipe.dit.blocks:
+                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+            pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
+            pipe.sp_size = get_sequence_parallel_world_size()
         return pipe
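
For orientation, a minimal sketch (added; mirrors the example script further down, argument values illustrative) of how the patched constructor is meant to be called once the distributed state is initialized:

    import torch
    import torch.distributed as dist
    pipe = WanVideoPipeline.from_model_manager(
        model_manager,                      # an already-populated ModelManager
        torch_dtype=torch.bfloat16,
        device=f"cuda:{dist.get_rank()}",
        use_usp=dist.get_world_size() > 1,  # patch the DiT only in multi-rank runs
    )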
@@ -375,11 +388,15 @@ def model_fn_wan_video(
         x = tea_cache.update(x)
     else:
         # blocks
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
         for block in dit.blocks:
             x = block(x, context, t_mod, freqs)
         if tea_cache is not None:
             tea_cache.store(x)
     x = dit.head(x, t)
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        x = get_sp_group().all_gather(x, dim=1)
     x = dit.unpatchify(x, (f, h, w))
     return x
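
A small single-process sanity sketch (added; toy tensor, hypothetical world size) of the invariant the chunk/all-gather pair relies on: splitting along dim=1 and concatenating the pieces back is the identity when the sequence divides evenly.

    import torch
    x = torch.randn(1, 8, 16)                  # (batch, seq, dim) toy activation
    world_size = 4                              # seq = 8 divides evenly
    chunks = torch.chunk(x, world_size, dim=1)  # per-rank slices
    assert torch.equal(torch.cat(chunks, dim=1), x)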

View File

@@ -1,29 +1,47 @@
 import torch
 from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
 from modelscope import snapshot_download
+import torch.distributed as dist

 # Download models
-snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
+# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")

 # Load models
 model_manager = ModelManager(device="cpu")
 model_manager.load_models(
     [
         [
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-            "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+            "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
         ],
-        "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-        "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
+        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
+        "/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
     ],
     torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
 )
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
+dist.init_process_group(
+    backend="nccl",
+    init_method="env://",
+)
+from xfuser.core.distributed import (initialize_model_parallel,
+                                     init_distributed_environment)
+init_distributed_environment(
+    rank=dist.get_rank(), world_size=dist.get_world_size())
+initialize_model_parallel(
+    sequence_parallel_degree=dist.get_world_size(),
+    ring_degree=1,
+    ulysses_degree=dist.get_world_size(),
+)
+torch.cuda.set_device(dist.get_rank())
+pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False)
 pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.

 # Text-to-video
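
A note on launching (added; the script path is hypothetical): dist.init_process_group with init_method="env://" expects the usual torchrun rendezvous variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), so this script is meant to be started with something like `torchrun --nproc_per_node=8 examples/wanvideo/wan_t2v_usp.py` rather than plain `python`.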