diff --git a/diffsynth/distributed/__init__.py b/diffsynth/distributed/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py new file mode 100644 index 0000000..a144e08 --- /dev/null +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -0,0 +1,127 @@ +import torch +from typing import Optional +from einops import rearrange +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def usp_dit_forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, (f, h, w)) + return x + + +def usp_attn_forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + + x = xFuserLongContextAttention()( + None, + query=q, + key=k, + value=v, + ) + x = x.flatten(2) + + return self.o(x) \ No newline at end of file diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 439d311..fdbcfc9 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1,3 +1,4 @@ +import types from ..models import ModelManager from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder @@ -30,9 +31,10 @@ class WanVideoPipeline(BasePipeline): self.image_encoder: WanImageEncoder = None self.dit: WanModel = None self.vae: WanVideoVAE = None - self.model_names = ['text_encoder', 'dit', 'vae'] + self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder'] self.height_division_factor = 16 self.width_division_factor = 16 + self.use_unified_sequence_parallel = False def enable_vram_management(self, num_persistent_param_in_dit=None): @@ -135,11 +137,20 @@ class WanVideoPipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None): + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): if device is None: device = model_manager.device if torch_dtype is None: torch_dtype = model_manager.torch_dtype pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) pipe.fetch_models(model_manager) + if use_usp: + from xfuser.core.distributed import get_sequence_parallel_world_size + from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward + + for block in pipe.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit) + pipe.sp_size = get_sequence_parallel_world_size() + pipe.use_unified_sequence_parallel = True return pipe @@ -189,6 +200,10 @@ class WanVideoPipeline(BasePipeline): def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return frames + + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} @torch.no_grad() @@ -258,6 +273,9 @@ class WanVideoPipeline(BasePipeline): # TeaCache tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} + + # Unified Sequence Parallel + usp_kwargs = self.prepare_unified_sequence_parallel() # Denoise self.load_models_to_device(["dit"]) @@ -265,9 +283,9 @@ class WanVideoPipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference - noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi) + noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs) if cfg_scale != 1.0: - noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega) + noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -346,8 +364,15 @@ def model_fn_wan_video( clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, **kwargs, ): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -375,11 +400,17 @@ def model_fn_wan_video( x = tea_cache.update(x) else: # blocks + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] for block in dit.blocks: x = block(x, context, t_mod, freqs) if tea_cache is not None: tea_cache.store(x) x = dit.head(x, t) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) x = dit.unpatchify(x, (f, h, w)) return x diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index b3f5ade..f8b3e0b 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -49,6 +49,20 @@ We present a detailed table here. The model is tested on a single A100. https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f +### Parallel Inference + +1. Unified Sequence Parallel (USP) + +```bash +pip install xfuser>=0.4.3 +``` + +```bash +torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py +``` + +2. Tensor Parallel + Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py). ### Wan-Video-14B-I2V diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py new file mode 100644 index 0000000..8837294 --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video_usp.py @@ -0,0 +1,58 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download +import torch.distributed as dist + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. +) + +dist.init_process_group( + backend="nccl", + init_method="env://", +) +from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) +init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + +initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), +) +torch.cuda.set_device(dist.get_rank()) + +pipe = WanVideoPipeline.from_model_manager(model_manager, + torch_dtype=torch.bfloat16, + device=f"cuda:{dist.get_rank()}", + use_usp=True if dist.get_world_size() > 1 else False) +pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + seed=0, tiled=True +) +if dist.get_rank() == 0: + save_video(video, "video1.mp4", fps=25, quality=5)