mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
fix usp dependency
This commit is contained in:
@@ -90,7 +90,6 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
|
||||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -13,10 +13,6 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch.distributed as dist
|
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|
||||||
get_sequence_parallel_world_size,
|
|
||||||
get_sp_group)
|
|
||||||
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
@@ -35,9 +31,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.image_encoder: WanImageEncoder = None
|
self.image_encoder: WanImageEncoder = None
|
||||||
self.dit: WanModel = None
|
self.dit: WanModel = None
|
||||||
self.vae: WanVideoVAE = None
|
self.vae: WanVideoVAE = None
|
||||||
self.model_names = ['text_encoder', 'dit', 'vae']
|
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
|
||||||
self.height_division_factor = 16
|
self.height_division_factor = 16
|
||||||
self.width_division_factor = 16
|
self.width_division_factor = 16
|
||||||
|
self.use_unified_sequence_parallel = False
|
||||||
|
|
||||||
|
|
||||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||||
@@ -153,6 +150,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
|
||||||
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
pipe.dit.forward = types.MethodType(usp_dit_forward, pipe.dit)
|
||||||
pipe.sp_size = get_sequence_parallel_world_size()
|
pipe.sp_size = get_sequence_parallel_world_size()
|
||||||
|
pipe.use_unified_sequence_parallel = True
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@@ -204,6 +202,10 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_unified_sequence_parallel(self):
|
||||||
|
return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -272,15 +274,18 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||||
|
|
||||||
|
# Unified Sequence Parallel
|
||||||
|
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(["dit"])
|
self.load_models_to_device(["dit"])
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
|
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
|
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
@@ -359,8 +364,15 @@ def model_fn_wan_video(
|
|||||||
clip_feature: Optional[torch.Tensor] = None,
|
clip_feature: Optional[torch.Tensor] = None,
|
||||||
y: Optional[torch.Tensor] = None,
|
y: Optional[torch.Tensor] = None,
|
||||||
tea_cache: TeaCache = None,
|
tea_cache: TeaCache = None,
|
||||||
|
use_unified_sequence_parallel: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if use_unified_sequence_parallel:
|
||||||
|
import torch.distributed as dist
|
||||||
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
|
get_sequence_parallel_world_size,
|
||||||
|
get_sp_group)
|
||||||
|
|
||||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||||
context = dit.text_embedding(context)
|
context = dit.text_embedding(context)
|
||||||
@@ -388,15 +400,17 @@ def model_fn_wan_video(
|
|||||||
x = tea_cache.update(x)
|
x = tea_cache.update(x)
|
||||||
else:
|
else:
|
||||||
# blocks
|
# blocks
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if use_unified_sequence_parallel:
|
||||||
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
|
||||||
for block in dit.blocks:
|
for block in dit.blocks:
|
||||||
x = block(x, context, t_mod, freqs)
|
x = block(x, context, t_mod, freqs)
|
||||||
if tea_cache is not None:
|
if tea_cache is not None:
|
||||||
tea_cache.store(x)
|
tea_cache.store(x)
|
||||||
|
|
||||||
x = dit.head(x, t)
|
x = dit.head(x, t)
|
||||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
if use_unified_sequence_parallel:
|
||||||
x = get_sp_group().all_gather(x, dim=1)
|
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||||
|
x = get_sp_group().all_gather(x, dim=1)
|
||||||
x = dit.unpatchify(x, (f, h, w))
|
x = dit.unpatchify(x, (f, h, w))
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ pip install xfuser>=0.4.3
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
torchrun --standalone --nproc_per_node=8 ./wan_14b_text_to_video_usp.py
|
torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Tensor Parallel
|
2. Tensor Parallel
|
||||||
|
|||||||
@@ -54,4 +54,5 @@ video = pipe(
|
|||||||
num_inference_steps=50,
|
num_inference_steps=50,
|
||||||
seed=0, tiled=True
|
seed=0, tiled=True
|
||||||
)
|
)
|
||||||
save_video(video, "video1.mp4", fps=25, quality=5)
|
if dist.get_rank() == 0:
|
||||||
|
save_video(video, "video1.mp4", fps=25, quality=5)
|
||||||
|
|||||||
Reference in New Issue
Block a user