mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
Merge pull request #485 from modelscope/usp
support Unified Sequence Parallel
This commit is contained in:
0
diffsynth/distributed/__init__.py
Normal file
0
diffsynth/distributed/__init__.py
Normal file
127
diffsynth/distributed/xdit_context_parallel.py
Normal file
127
diffsynth/distributed/xdit_context_parallel.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
import torch
|
||||||
|
from typing import Optional
|
||||||
|
from einops import rearrange
|
||||||
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group)
|
||||||
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
|
|
||||||
|
def sinusoidal_embedding_1d(dim, position):
|
||||||
|
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
||||||
|
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
||||||
|
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
||||||
|
return x.to(position.dtype)
|
||||||
|
|
||||||
|
def pad_freqs(original_tensor, target_len):
|
||||||
|
seq_len, s1, s2 = original_tensor.shape
|
||||||
|
pad_size = target_len - seq_len
|
||||||
|
padding_tensor = torch.ones(
|
||||||
|
pad_size,
|
||||||
|
s1,
|
||||||
|
s2,
|
||||||
|
dtype=original_tensor.dtype,
|
||||||
|
device=original_tensor.device)
|
||||||
|
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||||
|
return padded_tensor
|
||||||
|
|
||||||
|
def rope_apply(x, freqs, num_heads):
|
||||||
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
|
s_per_rank = x.shape[1]
|
||||||
|
|
||||||
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
|
||||||
|
sp_size = get_sequence_parallel_world_size()
|
||||||
|
sp_rank = get_sequence_parallel_rank()
|
||||||
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
|
|
||||||
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
def usp_dit_forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
context: torch.Tensor,
|
||||||
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
|
y: Optional[torch.Tensor] = None,
|
||||||
|
use_gradient_checkpointing: bool = False,
|
||||||
|
use_gradient_checkpointing_offload: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
t = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, timestep))
|
||||||
|
t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
if self.has_image_input:
|
||||||
|
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||||
|
clip_embdding = self.img_emb(clip_feature)
|
||||||
|
context = torch.cat([clip_embdding, context], dim=1)
|
||||||
|
|
||||||
|
x, (f, h, w) = self.patchify(x)
|
||||||
|
|
||||||
|
freqs = torch.cat([
|
||||||
|
self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||||
|
self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||||
|
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = torch.chunk(
|
||||||
|
x, get_sequence_parallel_world_size(),
|
||||||
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
if self.training and use_gradient_checkpointing:
|
||||||
|
if use_gradient_checkpointing_offload:
|
||||||
|
with torch.autograd.graph.save_on_cpu():
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
x, context, t_mod, freqs,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, context, t_mod, freqs)
|
||||||
|
|
||||||
|
x = self.head(x, t)
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, (f, h, w))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def usp_attn_forward(self, x, freqs):
|
||||||
|
q = self.norm_q(self.q(x))
|
||||||
|
k = self.norm_k(self.k(x))
|
||||||
|
v = self.v(x)
|
||||||
|
|
||||||
|
q = rope_apply(q, freqs, self.num_heads)
|
||||||
|
k = rope_apply(k, freqs, self.num_heads)
|
||||||
|
q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
|
||||||
|
x = xFuserLongContextAttention()(
|
||||||
|
None,
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
)
|
||||||
|
x = x.flatten(2)
|
||||||
|
|
||||||
|
return self.o(x)
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import types
|
||||||
from ..models import ModelManager
|
from ..models import ModelManager
|
||||||
from ..models.wan_video_dit import WanModel
|
from ..models.wan_video_dit import WanModel
|
||||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||||
@@ -30,9 +31,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.image_encoder: WanImageEncoder = None
|
self.image_encoder: WanImageEncoder = None
|
||||||
self.dit: WanModel = None
|
self.dit: WanModel = None
|
||||||
self.vae: WanVideoVAE = None
|
self.vae: WanVideoVAE = None
|
||||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
||||||
self.height_division_factor = 16
|
self.height_division_factor = 16
|
||||||
self.width_division_factor = 16
|
self.width_division_factor = 16
|
||||||
|
self.use_unified_sequence_parallel = False
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
@@ -135,11 +137,20 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, use_usp=False):
|
||||||
if device is None: device = model_manager.device
|
if device is None: device = model_manager.device
|
||||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
pipe.fetch_models(model_manager)
|
pipe.fetch_models(model_manager)
|
||||||
|
if use_usp:
|
||||||
|
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||||
|
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
|
||||||
|
|
||||||
|
for block in pipe.dit.blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
|
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
||||||
|
pipe.sp_size = get_sequence_parallel_world_size()
|
||||||
|
pipe.use_unified_sequence_parallel = True
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -191,6 +202,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_unified_sequence_parallel(self):
|
||||||
|
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -259,15 +274,18 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
|
# Unified Sequence Parallel
|
||||||
|
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(["dit"])
|
self.load_models_to_device(["dit"])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
|
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
|
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
@@ -346,8 +364,15 @@ def model_fn_wan_video(
|
|||||||
clip_feature: Optional[torch.Tensor] = None,
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
|
use_unified_sequence_parallel: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
import torch.distributed as dist
|
||||||
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group)
|
||||||
|
|
||||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||||
context = dit.text_embedding(context)
|
context = dit.text_embedding(context)
|
||||||
@@ -375,11 +400,17 @@ def model_fn_wan_video(
|
|||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
# blocks
|
# blocks
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
for block in dit.blocks:
|
for block in dit.blocks:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -49,6 +49,20 @@ We present a detailed table here. The model is tested on a single A100.
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
|
||||||
|
|
||||||
|
### Parallel Inference
|
||||||
|
|
||||||
|
1. Unified Sequence Parallel (USP)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install xfuser>=0.4.3
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Tensor Parallel
|
||||||
|
|
||||||
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
|
||||||
|
|
||||||
### Wan-Video-14B-I2V
|
### Wan-Video-14B-I2V
|
||||||
|
|||||||
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
58
examples/wanvideo/wan_14b_text_to_video_usp.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
# Download models
|
||||||
|
snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B")
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(device="cpu")
|
||||||
|
model_manager.load_models(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||||
|
],
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||||
|
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
|
||||||
|
],
|
||||||
|
torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
|
||||||
|
)
|
||||||
|
|
||||||
|
dist.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
init_method="env://",
|
||||||
|
)
|
||||||
|
from xfuser.core.distributed import (initialize_model_parallel,
|
||||||
|
init_distributed_environment)
|
||||||
|
init_distributed_environment(
|
||||||
|
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||||
|
|
||||||
|
initialize_model_parallel(
|
||||||
|
sequence_parallel_degree=dist.get_world_size(),
|
||||||
|
ring_degree=1,
|
||||||
|
ulysses_degree=dist.get_world_size(),
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(dist.get_rank())
|
||||||
|
|
||||||
|
pipe = WanVideoPipeline.from_model_manager(model_manager,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=f"cuda:{dist.get_rank()}",
|
||||||
|
use_usp=True if dist.get_world_size() > 1 else False)
|
||||||
|
pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
|
||||||
|
|
||||||
|
# Text-to-video
|
||||||
|
video = pipe(
|
||||||
|
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||||
|
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
num_inference_steps=50,
|
||||||
|
seed=0, tiled=True
|
||||||
|
)
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
save_video(video, "video1.mp4", fps=25, quality=5)
|
||||||
Reference in New Issue
Block a user