From 42cb7d96bbcf48bda9a7909d31dd767d8a30bd09 Mon Sep 17 00:00:00 2001 From: Jinzhe Pan Date: Mon, 17 Mar 2025 08:31:45 +0000 Subject: [PATCH 1/5] feat: sp for wan --- .../xdit_context_parallel.cpython-310.pyc | Bin 0 -> 3754 bytes .../distributed/xdit_context_parallel.py | 127 ++++++++++++++++++ diffsynth/models/wan_video_dit.py | 2 + diffsynth/pipelines/wan_video.py | 19 ++- examples/wanvideo/wan_14b_text_to_video.py | 38 ++++-- 5 files changed, 175 insertions(+), 11 deletions(-) create mode 100644 diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc create mode 100644 diffsynth/distributed/xdit_context_parallel.py diff --git a/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc b/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..844c90076f202adc0be23a9f0c4326ec2fff0bdb GIT binary patch literal 3754 zcmb7HTW=f372cV>aJeEWN|r54b{wGyTC-^6I!M!7;N}JuMGN)81==apLHpwN5A=b2DUkk@KJ~4@f@jrpzev>q`-7-_Xp?0%LDV25%XRV)PQTfi>+Rn2h?Lk_QqN1D7 ztoz#TMI}g9XuF70&9MYoRmV{+;_>sWh)EdQ$MB@cVxigbrFr;0>l;NAnfs_~@vi80CC`N@b0#kL@XhIdv*l@vGL>vMr{-*?2$Y5&<~|m9u&Z#ESG2bx5vwLh$J1^xjB=^Sv9yz=vM{nT z(o|Q7#n-EuiU%>wHiQ)wnd$~*q{!wHnd%zM)>L?LTl3;Ah&MX3%NA)Ry9G z6(jBixj~!PXo ztW!RPqf}mHZ$Qez-n6Lp2?zR|(?0E3!;S&8Q`@kx-wIa6nw2vneJz}!w*O?~-cGS| zGu;uLoj337Py=7B^nMD5kf57#8=m8s%vzi>*-`14rL^YZcv1d~b0M zzl_JSC}pl|XHKRYy(G@VzKqI=lA4|9ZQvc?FA_0eF6<3u?;y`&AYMAq-}w5mFzXYP zh#m=d`HOv;h6|=5V5GvT8lUhjYu2n9 zSdYP&oLQ@CRL!bIbuuZz3}17>2N!654WDQDT1!4~nl&!2H?W0Er=G3S9x(LQ2Am2z z*|RpVd=0l)6*b{Mw5DyGbG>5arBmWSd?+_gEto$^?ea&~@Ajr^s9jP2t~`{Q`X~4| zM^|TCbL%gg%tC8>bZvC4YFBHg+Nm|VQIRy4`gt8=My7q6I4X)*EBnS1AA7I~ z1Dp$={0=qAc6x-U%ZI`mdFUmSq7Cv#t4|y7hyK}e_I?lD=MPP$Nt-AIE%ztb+@CG% z%l3^AA*OqAnor6CsNjt=F-hbFL*xLy+?#%rMg4JfcNFETy78J1%Xb3IsC!y|96%> z+`37=t0bHn*iv?Obc?{?WiT~iQ~LxtxaC-{n&Dh!BiS<%zCKBp*k)pXi?X4Y^Y!Ww zuo;h|feej+4WOF1r)%+eK%5lN%HA;UpQzioigpM(wI`4BC>3g#L}2jN-*Y?%hb zcqROkXsDR*|5sGt;Z)$ebG(GcGoUmud(j3`A^_8fTT=2w2=1NQNDo|u>SYU|zp{ag z7&B>vjW+Q9yiMcog|~+_pu{<+TB+joKcaukovnw`qhcH z59EndPt!|%Qcj_NI1!z2QfC93a&XqI(6^&4tZ1Yany zPqiVDn21r|OC7dJhlu^@5+D*V6+0kp)`sd@yoE2l4qxP9oOw?JFW~9`Ovotr_x3vL zi#)_bb`A?yggtPOIL-3H98qr`_vs(WZZA`E_d@Qx$ViO082i3)kuSmjm#|kJ)A<(v zUX-~irOGCYF@n(EuS`;^oqJi97~?QUyKvK&*Y)3Bfae{8+=AT02R1d>25ZnflEJkZ2rtAdv*C2+~ZGpE%?0g!G8f^oyg<> literal 0 HcmV?d00001 diff --git a/diffsynth/distributed/xdit_context_parallel.py b/diffsynth/distributed/xdit_context_parallel.py new file mode 100644 index 0000000..a144e08 --- /dev/null +++ b/diffsynth/distributed/xdit_context_parallel.py @@ -0,0 +1,127 @@ +import torch +from typing import Optional +from einops import rearrange +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) +from xfuser.core.long_ctx_attention import xFuserLongContextAttention + +def sinusoidal_embedding_1d(dim, position): + sinusoid = torch.outer(position.type(torch.float64), torch.pow( + 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x.to(position.dtype) + +def pad_freqs(original_tensor, target_len): + seq_len, s1, s2 = original_tensor.shape + pad_size = target_len - seq_len + padding_tensor = torch.ones( + pad_size, + s1, + s2, + dtype=original_tensor.dtype, + device=original_tensor.device) + padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) + return padded_tensor + +def rope_apply(x, freqs, num_heads): + x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) + s_per_rank = x.shape[1] + + x_out = torch.view_as_complex(x.to(torch.float64).reshape( + x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + freqs = pad_freqs(freqs, s_per_rank * sp_size) + freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] + + x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) + return x_out.to(x.dtype) + +def usp_dit_forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + # Context Parallel + x = torch.chunk( + x, get_sequence_parallel_world_size(), + dim=1)[get_sequence_parallel_rank()] + + for block in self.blocks: + if self.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + + x = self.head(x, t) + + # Context Parallel + x = get_sp_group().all_gather(x, dim=1) + + # unpatchify + x = self.unpatchify(x, (f, h, w)) + return x + + +def usp_attn_forward(self, x, freqs): + q = self.norm_q(self.q(x)) + k = self.norm_k(self.k(x)) + v = self.v(x) + + q = rope_apply(q, freqs, self.num_heads) + k = rope_apply(k, freqs, self.num_heads) + q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads) + k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads) + v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads) + + x = xFuserLongContextAttention()( + None, + query=q, + key=k, + value=v, + ) + x = x.flatten(2) + + return self.o(x) \ No newline at end of file diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 32a79e3..618e7d8 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -90,6 +90,8 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) + + print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}") x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 76e1fa0..b79352f 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1,3 +1,4 @@ +import types from ..models import ModelManager from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder @@ -12,6 +13,10 @@ import numpy as np from PIL import Image from tqdm import tqdm from typing import Optional +import torch.distributed as dist +from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm @@ -135,11 +140,19 @@ class WanVideoPipeline(BasePipeline): @staticmethod - def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None): + def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False): if device is None: device = model_manager.device if torch_dtype is None: torch_dtype = model_manager.torch_dtype pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) pipe.fetch_models(model_manager) + if use_usp: + from xfuser.core.distributed import get_sequence_parallel_world_size + from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward + + for block in pipe.dit.blocks: + block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) + pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit) + pipe.sp_size = get_sequence_parallel_world_size() return pipe @@ -375,11 +388,15 @@ def model_fn_wan_video( x = tea_cache.update(x) else: # blocks + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] for block in dit.blocks: x = block(x, context, t_mod, freqs) if tea_cache is not None: tea_cache.store(x) x = dit.head(x, t) + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) x = dit.unpatchify(x, (f, h, w)) return x diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py index 654565d..d67e1d5 100644 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -1,29 +1,47 @@ import torch from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData from modelscope import snapshot_download +import torch.distributed as dist # Download models -snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") +# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") # Load models model_manager = ModelManager(device="cpu") model_manager.load_models( [ [ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", ], - "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth", ], torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. ) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") + +dist.init_process_group( + backend="nccl", + init_method="env://", +) +from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) +init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + +initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), +) +torch.cuda.set_device(dist.get_rank()) + +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False) pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. # Text-to-video From 1e58e6ef8205c93addd052b68a21b745a6d2ae49 Mon Sep 17 00:00:00 2001 From: feifeibear Date: Mon, 17 Mar 2025 09:00:52 +0000 Subject: [PATCH 2/5] fix some bugs --- diffsynth/distributed/__init__.py | 0 .../xdit_context_parallel.cpython-310.pyc | Bin 3754 -> 0 bytes requirements.txt | 1 + 3 files changed, 1 insertion(+) create mode 100644 diffsynth/distributed/__init__.py delete mode 100644 diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc diff --git a/diffsynth/distributed/__init__.py b/diffsynth/distributed/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc b/diffsynth/distributed/__pycache__/xdit_context_parallel.cpython-310.pyc deleted file mode 100644 index 844c90076f202adc0be23a9f0c4326ec2fff0bdb..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3754 zcmb7HTW=f372cV>aJeEWN|r54b{wGyTC-^6I!M!7;N}JuMGN)81==apLHpwN5A=b2DUkk@KJ~4@f@jrpzev>q`-7-_Xp?0%LDV25%XRV)PQTfi>+Rn2h?Lk_QqN1D7 ztoz#TMI}g9XuF70&9MYoRmV{+;_>sWh)EdQ$MB@cVxigbrFr;0>l;NAnfs_~@vi80CC`N@b0#kL@XhIdv*l@vGL>vMr{-*?2$Y5&<~|m9u&Z#ESG2bx5vwLh$J1^xjB=^Sv9yz=vM{nT z(o|Q7#n-EuiU%>wHiQ)wnd$~*q{!wHnd%zM)>L?LTl3;Ah&MX3%NA)Ry9G z6(jBixj~!PXo ztW!RPqf}mHZ$Qez-n6Lp2?zR|(?0E3!;S&8Q`@kx-wIa6nw2vneJz}!w*O?~-cGS| zGu;uLoj337Py=7B^nMD5kf57#8=m8s%vzi>*-`14rL^YZcv1d~b0M zzl_JSC}pl|XHKRYy(G@VzKqI=lA4|9ZQvc?FA_0eF6<3u?;y`&AYMAq-}w5mFzXYP zh#m=d`HOv;h6|=5V5GvT8lUhjYu2n9 zSdYP&oLQ@CRL!bIbuuZz3}17>2N!654WDQDT1!4~nl&!2H?W0Er=G3S9x(LQ2Am2z z*|RpVd=0l)6*b{Mw5DyGbG>5arBmWSd?+_gEto$^?ea&~@Ajr^s9jP2t~`{Q`X~4| zM^|TCbL%gg%tC8>bZvC4YFBHg+Nm|VQIRy4`gt8=My7q6I4X)*EBnS1AA7I~ z1Dp$={0=qAc6x-U%ZI`mdFUmSq7Cv#t4|y7hyK}e_I?lD=MPP$Nt-AIE%ztb+@CG% z%l3^AA*OqAnor6CsNjt=F-hbFL*xLy+?#%rMg4JfcNFETy78J1%Xb3IsC!y|96%> z+`37=t0bHn*iv?Obc?{?WiT~iQ~LxtxaC-{n&Dh!BiS<%zCKBp*k)pXi?X4Y^Y!Ww zuo;h|feej+4WOF1r)%+eK%5lN%HA;UpQzioigpM(wI`4BC>3g#L}2jN-*Y?%hb zcqROkXsDR*|5sGt;Z)$ebG(GcGoUmud(j3`A^_8fTT=2w2=1NQNDo|u>SYU|zp{ag z7&B>vjW+Q9yiMcog|~+_pu{<+TB+joKcaukovnw`qhcH z59EndPt!|%Qcj_NI1!z2QfC93a&XqI(6^&4tZ1Yany zPqiVDn21r|OC7dJhlu^@5+D*V6+0kp)`sd@yoE2l4qxP9oOw?JFW~9`Ovotr_x3vL zi#)_bb`A?yggtPOIL-3H98qr`_vs(WZZA`E_d@Qx$ViO082i3)kuSmjm#|kJ)A<(v zUX-~irOGCYF@n(EuS`;^oqJi97~?QUyKvK&*Y)3Bfae{8+=AT02R1d>25ZnflEJkZ2rtAdv*C2+~ZGpE%?0g!G8f^oyg<> diff --git a/requirements.txt b/requirements.txt index 63a871b..7dc3846 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ sentencepiece protobuf modelscope ftfy +xfuser>=0.4.2 From d8b250607aac7764a5230c208b593ea5ae020c32 Mon Sep 17 00:00:00 2001 From: feifeibear Date: Mon, 17 Mar 2025 09:04:51 +0000 Subject: [PATCH 3/5] polish code --- diffsynth/models/wan_video_dit.py | 1 - examples/wanvideo/wan_14b_text_to_video.py | 23 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 618e7d8..a2c55e1 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -91,7 +91,6 @@ def rope_apply(x, freqs, num_heads): x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - print(f"x_out.shape: {x_out.shape}, freqs.shape: {freqs.shape}") x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py index d67e1d5..dcb2f29 100644 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -5,22 +5,22 @@ import torch.distributed as dist # Download models -# snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") +snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") # Load models model_manager = ModelManager(device="cpu") model_manager.load_models( [ [ - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", ], - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", - "/demo-huabei2/models/dit/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", ], torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. ) @@ -41,7 +41,10 @@ initialize_model_parallel( ) torch.cuda.set_device(dist.get_rank()) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device=f"cuda:{dist.get_rank()}", use_usp=True if dist.get_world_size() > 1 else False) +pipe = WanVideoPipeline.from_model_manager(model_manager, + torch_dtype=torch.bfloat16, + device=f"cuda:{dist.get_rank()}", + use_usp=True if dist.get_world_size() > 1 else False) pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. # Text-to-video From d0fed6ba72aac622ec9d9066045307f945740314 Mon Sep 17 00:00:00 2001 From: ByteDance Date: Tue, 25 Mar 2025 11:51:37 +0800 Subject: [PATCH 4/5] add usp for wanx --- examples/wanvideo/README.md | 14 +++++ examples/wanvideo/wan_14b_text_to_video.py | 25 +------- .../wanvideo/wan_14b_text_to_video_usp.py | 57 +++++++++++++++++++ requirements.txt | 1 - 4 files changed, 73 insertions(+), 24 deletions(-) create mode 100644 examples/wanvideo/wan_14b_text_to_video_usp.py diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index b3f5ade..1b8ac6c 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -49,6 +49,20 @@ We present a detailed table here. The model is tested on a single A100. https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f +### Parallel Inference + +1. Unified Sequence Parallel (USP) + +```bash +pip install xfuser>=0.4.3 +``` + +```bash +torchrun --standalone --nproc_per_node=8 ./wan_14b_text_to_video_usp.py +``` + +2. Tensor Parallel + Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py). ### Wan-Video-14B-I2V diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py index dcb2f29..2c4f15b 100644 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -1,7 +1,6 @@ import torch from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData from modelscope import snapshot_download -import torch.distributed as dist # Download models @@ -24,27 +23,7 @@ model_manager.load_models( ], torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. ) - -dist.init_process_group( - backend="nccl", - init_method="env://", -) -from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) -init_distributed_environment( - rank=dist.get_rank(), world_size=dist.get_world_size()) - -initialize_model_parallel( - sequence_parallel_degree=dist.get_world_size(), - ring_degree=1, - ulysses_degree=dist.get_world_size(), -) -torch.cuda.set_device(dist.get_rank()) - -pipe = WanVideoPipeline.from_model_manager(model_manager, - torch_dtype=torch.bfloat16, - device=f"cuda:{dist.get_rank()}", - use_usp=True if dist.get_world_size() > 1 else False) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. # Text-to-video @@ -54,4 +33,4 @@ video = pipe( num_inference_steps=50, seed=0, tiled=True ) -save_video(video, "video1.mp4", fps=25, quality=5) +save_video(video, "video1.mp4", fps=25, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py new file mode 100644 index 0000000..dcb2f29 --- /dev/null +++ b/examples/wanvideo/wan_14b_text_to_video_usp.py @@ -0,0 +1,57 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download +import torch.distributed as dist + + +# Download models +snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") + +# Load models +model_manager = ModelManager(device="cpu") +model_manager.load_models( + [ + [ + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", + "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", + ], + "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", + "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", + ], + torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. +) + +dist.init_process_group( + backend="nccl", + init_method="env://", +) +from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) +init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + +initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=1, + ulysses_degree=dist.get_world_size(), +) +torch.cuda.set_device(dist.get_rank()) + +pipe = WanVideoPipeline.from_model_manager(model_manager, + torch_dtype=torch.bfloat16, + device=f"cuda:{dist.get_rank()}", + use_usp=True if dist.get_world_size() > 1 else False) +pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. + +# Text-to-video +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + seed=0, tiled=True +) +save_video(video, "video1.mp4", fps=25, quality=5) diff --git a/requirements.txt b/requirements.txt index 7dc3846..63a871b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,3 @@ sentencepiece protobuf modelscope ftfy -xfuser>=0.4.2 From 4e43d4d4613271e292cb6b4935f303a2d9ca337c Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 25 Mar 2025 19:26:24 +0800 Subject: [PATCH 5/5] fix usp dependency --- diffsynth/models/wan_video_dit.py | 1 - diffsynth/pipelines/wan_video.py | 36 +++++++++++++------ examples/wanvideo/README.md | 2 +- examples/wanvideo/wan_14b_text_to_video.py | 2 +- .../wanvideo/wan_14b_text_to_video_usp.py | 3 +- 5 files changed, 29 insertions(+), 15 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index e6e279b..da1aafc 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -90,7 +90,6 @@ def rope_apply(x, freqs, num_heads): x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x_out = torch.view_as_complex(x.to(torch.float64).reshape( x.shape[0], x.shape[1], x.shape[2], -1, 2)) - x_out = torch.view_as_real(x_out * freqs).flatten(2) return x_out.to(x.dtype) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index a38282e..fdbcfc9 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -13,10 +13,6 @@ import numpy as np from PIL import Image from tqdm import tqdm from typing import Optional -import torch.distributed as dist -from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm @@ -35,9 +31,10 @@ class WanVideoPipeline(BasePipeline): self.image_encoder: WanImageEncoder = None self.dit: WanModel = None self.vae: WanVideoVAE = None - self.model_names = ['text_encoder', 'dit', 'vae'] + self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder'] self.height_division_factor = 16 self.width_division_factor = 16 + self.use_unified_sequence_parallel = False def enable_vram_management(self, num_persistent_param_in_dit=None): @@ -153,6 +150,7 @@ class WanVideoPipeline(BasePipeline): block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn) pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit) pipe.sp_size = get_sequence_parallel_world_size() + pipe.use_unified_sequence_parallel = True return pipe @@ -202,6 +200,10 @@ class WanVideoPipeline(BasePipeline): def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return frames + + + def prepare_unified_sequence_parallel(self): + return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel} @torch.no_grad() @@ -271,6 +273,9 @@ class WanVideoPipeline(BasePipeline): # TeaCache tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} + + # Unified Sequence Parallel + usp_kwargs = self.prepare_unified_sequence_parallel() # Denoise self.load_models_to_device(["dit"]) @@ -278,9 +283,9 @@ class WanVideoPipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference - noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi) + noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs) if cfg_scale != 1.0: - noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega) + noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -359,8 +364,15 @@ def model_fn_wan_video( clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, + use_unified_sequence_parallel: bool = False, **kwargs, ): + if use_unified_sequence_parallel: + import torch.distributed as dist + from xfuser.core.distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group) + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -388,15 +400,17 @@ def model_fn_wan_video( x = tea_cache.update(x) else: # blocks - if dist.is_initialized() and dist.get_world_size() > 1: - x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] for block in dit.blocks: x = block(x, context, t_mod, freqs) if tea_cache is not None: tea_cache.store(x) x = dit.head(x, t) - if dist.is_initialized() and dist.get_world_size() > 1: - x = get_sp_group().all_gather(x, dim=1) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + x = get_sp_group().all_gather(x, dim=1) x = dit.unpatchify(x, (f, h, w)) return x diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 1b8ac6c..f8b3e0b 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -58,7 +58,7 @@ pip install xfuser>=0.4.3 ``` ```bash -torchrun --standalone --nproc_per_node=8 ./wan_14b_text_to_video_usp.py +torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py ``` 2. Tensor Parallel diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py index 2c4f15b..654565d 100644 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ b/examples/wanvideo/wan_14b_text_to_video.py @@ -33,4 +33,4 @@ video = pipe( num_inference_steps=50, seed=0, tiled=True ) -save_video(video, "video1.mp4", fps=25, quality=5) \ No newline at end of file +save_video(video, "video1.mp4", fps=25, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video_usp.py b/examples/wanvideo/wan_14b_text_to_video_usp.py index dcb2f29..8837294 100644 --- a/examples/wanvideo/wan_14b_text_to_video_usp.py +++ b/examples/wanvideo/wan_14b_text_to_video_usp.py @@ -54,4 +54,5 @@ video = pipe( num_inference_steps=50, seed=0, tiled=True ) -save_video(video, "video1.mp4", fps=25, quality=5) +if dist.get_rank() == 0: + save_video(video, "video1.mp4", fps=25, quality=5)