diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index e6e279b..da1aafc 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -90,7 +90,7 @@ def rope_apply(x, freqs, num_heads):
     x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
     x_out = torch.view_as_complex(x.to(torch.float64).reshape(
         x.shape[0], x.shape[1], x.shape[2], -1, 2))
     x_out = torch.view_as_real(x_out * freqs).flatten(2)
     return x_out.to(x.dtype)
 
 
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index a38282e..fdbcfc9 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -13,10 +13,6 @@ import numpy as np
 from PIL import Image
 from tqdm import tqdm
 from typing import Optional
-import torch.distributed as dist
-from xfuser.core.distributed import (get_sequence_parallel_rank,
-                                     get_sequence_parallel_world_size,
-                                     get_sp_group)
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 
@@ -35,9 +31,10 @@ class WanVideoPipeline(BasePipeline):
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae']
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
         self.height_division_factor = 16
         self.width_division_factor = 16
+        self.use_unified_sequence_parallel = False
 
 
     def enable_vram_management(self, num_persistent_param_in_dit=None):
@@ -153,6 +150,7 @@
             block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
         pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
         pipe.sp_size = get_sequence_parallel_world_size()
+        pipe.use_unified_sequence_parallel = True
         return pipe
 
 
@@ -202,6 +200,10 @@
     def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
         frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
         return frames
+
+
+    def prepare_unified_sequence_parallel(self):
+        return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
 
 
     @torch.no_grad()
@@ -271,6 +273,9 @@
         # TeaCache
         tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
         tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
+
+        # Unified Sequence Parallel
+        usp_kwargs = self.prepare_unified_sequence_parallel()
 
         # Denoise
         self.load_models_to_device(["dit"])
@@ -278,9 +283,9 @@
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
 
             # Inference
-            noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
+            noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
             if cfg_scale != 1.0:
-                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
+                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
                 noise_pred = noise_pred_posi
@@ -359,8 +364,15 @@
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
+    use_unified_sequence_parallel: bool = False,
     **kwargs,
 ):
+    if use_unified_sequence_parallel:
+        import torch.distributed as dist
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group)
+
     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)
@@ -388,15 +400,17 @@
         x = tea_cache.update(x)
     else:
         # blocks
-        if dist.is_initialized() and dist.get_world_size() > 1:
-            x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
+        if use_unified_sequence_parallel:
+            if dist.is_initialized() and dist.get_world_size() > 1:
+                x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
         for block in dit.blocks:
             x = block(x, context, t_mod, freqs)
         if tea_cache is not None:
             tea_cache.store(x)
     x = dit.head(x, t)
-    if dist.is_initialized() and dist.get_world_size() > 1:
-        x = get_sp_group().all_gather(x, dim=1)
+    if use_unified_sequence_parallel:
+        if dist.is_initialized() and dist.get_world_size() > 1:
+            x = get_sp_group().all_gather(x, dim=1)
     x = dit.unpatchify(x, (f, h, w))
     return x
 
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index 1b8ac6c..f8b3e0b 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -58,7 +58,7 @@ pip install xfuser>=0.4.3
 ```
 
 ```bash
-torchrun --standalone --nproc_per_node=8 ./wan_14b_text_to_video_usp.py
+torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
 ```
 
 2. Tensor Parallel
diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py
index 2c4f15b..654565d 100644
--- a/examples/wanvideo/wan_14b_text_to_video.py
+++ b/examples/wanvideo/wan_14b_text_to_video.py
@@ -33,4 +33,4 @@
 video = pipe(
     num_inference_steps=50, seed=0, tiled=True
 )
-save_video(video, "video1.mp4", fps=25, quality=5)
\ No newline at end of file
+save_video(video, "video1.mp4", fps=25, quality=5)
diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py
index dcb2f29..8837294 100644
--- a/examples/wanvideo/wan_14b_text_to_video_usp.py
+++ b/examples/wanvideo/wan_14b_text_to_video_usp.py
@@ -54,4 +54,5 @@
 video = pipe(
     num_inference_steps=50, seed=0, tiled=True
 )
-save_video(video, "video1.mp4", fps=25, quality=5)
+if dist.get_rank() == 0:
+    save_video(video, "video1.mp4", fps=25, quality=5)