wan usp bug fix

This commit is contained in:
lzws
2026-01-12 22:08:48 +08:00
committed by GitHub
parent a236a17f17
commit e99cdcf3b8

View File

@@ -122,11 +122,15 @@ class WanVideoPipeline(BasePipeline):
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE):
device = get_device_name()
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models