mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
wan usp bug fix
This commit is contained in:
@@ -122,11 +122,15 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
|
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
|
||||||
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
|
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
|
||||||
|
|
||||||
# Initialize pipeline
|
|
||||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
|
||||||
if use_usp:
|
if use_usp:
|
||||||
from ..utils.xfuser import initialize_usp
|
from ..utils.xfuser import initialize_usp
|
||||||
initialize_usp(device)
|
initialize_usp(device)
|
||||||
|
import torch.distributed as dist
|
||||||
|
from ..core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||||
|
if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE):
|
||||||
|
device = get_device_name()
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
# Fetch models
|
# Fetch models
|
||||||
|
|||||||
Reference in New Issue
Block a user