accelerate

This commit is contained in:
Artiprocher
2025-06-12 10:37:33 +08:00
parent 6a833c7134
commit b25c66b303
5 changed files with 90 additions and 80 deletions

View File

@@ -212,6 +212,7 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_FunReference(),
WanVideoUnit_SpeedControl(),
WanVideoUnit_VACE(),
WanVideoUnit_UnifiedSequenceParallel(),
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
]
@@ -375,6 +376,19 @@ class WanVideoPipeline(BasePipeline):
)
def initialize_usp(self):
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend="nccl", init_method="env://")
init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=1,
ulysses_degree=dist.get_world_size(),
)
torch.cuda.set_device(dist.get_rank())
def enable_usp(self):
from xfuser.core.distributed import get_sequence_parallel_world_size
from ..distributed.xdit_context_parallel import usp_attn_forward, usp_dit_forward
@@ -423,6 +437,7 @@ class WanVideoPipeline(BasePipeline):
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp: pipe.initialize_usp()
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
pipe.vae = model_manager.fetch_model("wan_video_vae")
@@ -434,6 +449,9 @@ class WanVideoPipeline(BasePipeline):
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
pipe.prompter.fetch_models(pipe.text_encoder)
pipe.prompter.fetch_tokenizer(tokenizer_config.path)
# Unified Sequence Parallel
if use_usp: pipe.enable_usp()
return pipe
@@ -492,11 +510,11 @@ class WanVideoPipeline(BasePipeline):
# Inputs
inputs_posi = {
"prompt": prompt,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id,
"tea_cache_l1_thresh": tea_cache_l1_thresh, "tea_cache_model_id": tea_cache_model_id, "num_inference_steps": num_inference_steps,
}
inputs_shared = {
"input_image": input_image,
@@ -507,7 +525,7 @@ class WanVideoPipeline(BasePipeline):
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift,
"sigma_shift": sigma_shift,
"motion_bucket_id": motion_bucket_id,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
@@ -811,6 +829,18 @@ class WanVideoUnit_VACE(PipelineUnit):
class WanVideoUnit_UnifiedSequenceParallel(PipelineUnit):
def __init__(self):
super().__init__(input_params=())
def process(self, pipe: WanVideoPipeline):
if hasattr(pipe, "use_unified_sequence_parallel"):
if pipe.use_unified_sequence_parallel:
return {"use_unified_sequence_parallel": True}
return {}
class WanVideoUnit_TeaCache(PipelineUnit):
def __init__(self):
super().__init__(