fix usp dependency

This commit is contained in:
Artiprocher
2025-03-25 19:26:24 +08:00
parent d0fed6ba72
commit 4e43d4d461
5 changed files with 29 additions and 15 deletions

View File

@@ -90,7 +90,6 @@ def rope_apply(x, freqs, num_heads):
 x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
 x_out = torch.view_as_complex(x.to(torch.float64).reshape(
     x.shape[0], x.shape[1], x.shape[2], -1, 2))
 x_out = torch.view_as_real(x_out * freqs).flatten(2)
 return x_out.to(x.dtype)

View File

@@ -13,10 +13,6 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 from typing import Optional
-import torch.distributed as dist
-from xfuser.core.distributed import (get_sequence_parallel_rank,
-                                     get_sequence_parallel_world_size,
-                                     get_sp_group)
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
@@ -35,9 +31,10 @@ class WanVideoPipeline(BasePipeline):
 self.image_encoder: WanImageEncoder = None
 self.dit: WanModel = None
 self.vae: WanVideoVAE = None
-self.model_names = ['text_encoder', 'dit', 'vae']
+self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
 self.height_division_factor = 16
 self.width_division_factor = 16
+self.use_unified_sequence_parallel = False
 def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -153,6 +150,7 @@ class WanVideoPipeline(BasePipeline):
 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
 pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
 pipe.sp_size = get_sequence_parallel_world_size()
+pipe.use_unified_sequence_parallel = True
 return pipe
@@ -202,6 +200,10 @@ class WanVideoPipeline(BasePipeline):
 def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
     frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
     return frames
+
+def prepare_unified_sequence_parallel(self):
+    return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
 @torch.no_grad()
@@ -271,6 +273,9 @@ class WanVideoPipeline(BasePipeline):
 # TeaCache
 tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
 tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+# Unified Sequence Parallel
+usp_kwargs = self.prepare_unified_sequence_parallel()
 # Denoise
 self.load_models_to_device(["dit"])
@@ -278,9 +283,9 @@ class WanVideoPipeline(BasePipeline):
 timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
 # Inference
-noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
+noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
 if cfg_scale != 1.0:
-    noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
+    noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
 else:
     noise_pred = noise_pred_posi
@@ -359,8 +364,15 @@ def model_fn_wan_video(
 clip_feature: Optional[torch.Tensor] = None,
 y: Optional[torch.Tensor] = None,
 tea_cache: TeaCache = None,
+use_unified_sequence_parallel: bool = False,
 **kwargs,
 ):
+if use_unified_sequence_parallel:
+    import torch.distributed as dist
+    from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                         get_sequence_parallel_world_size,
+                                         get_sp_group)
 t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
 t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
 context = dit.text_embedding(context)
@@ -388,15 +400,17 @@ def model_fn_wan_video(
 x = tea_cache.update(x)
 else:
 # blocks
-if dist.is_initialized() and dist.get_world_size() > 1:
-    x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+if use_unified_sequence_parallel:
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
 for block in dit.blocks:
     x = block(x, context, t_mod, freqs)
 if tea_cache is not None:
     tea_cache.store(x)
 x = dit.head(x, t)
-if dist.is_initialized() and dist.get_world_size() > 1:
-    x = get_sp_group().all_gather(x, dim=1)
+if use_unified_sequence_parallel:
+    if dist.is_initialized() and dist.get_world_size() > 1:
+        x = get_sp_group().all_gather(x, dim=1)
 x = dit.unpatchify(x, (f, h, w))
 return x

View File

@@ -58,7 +58,7 @@ pip install xfuser>=0.4.3
 ```
 ```bash
-torchrun --standalone --nproc_per_node=8 ./wan_14b_text_to_video_usp.py
+torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
 ```
 2. Tensor Parallel

View File

@@ -33,4 +33,4 @@ video = pipe(
 num_inference_steps=50,
 seed=0, tiled=True
 )
 save_video(video, "video1.mp4", fps=25, quality=5)

View File

@@ -54,4 +54,5 @@ video = pipe(
 num_inference_steps=50,
 seed=0, tiled=True
 )
-save_video(video, "video1.mp4", fps=25, quality=5)
+if dist.get_rank() == 0:
+    save_video(video, "video1.mp4", fps=25, quality=5)