From 31369bab154ed69e0fad1300acbd58f9fbc730ba Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 19 Jun 2025 10:04:24 +0800 Subject: [PATCH] update import --- diffsynth/pipelines/wan_video_new.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 668b16f..4a7a2a2 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -170,8 +170,9 @@ class ModelConfig: if self.path is None: if self.model_id is None or self.origin_file_pattern is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") - import torch.distributed as dist - skip_download = True if use_usp and dist.get_rank() != 0 else skip_download + if use_usp: + import torch.distributed as dist + skip_download = dist.get_rank() != 0 if not skip_download: downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( @@ -182,6 +183,7 @@ class ModelConfig: local_files_only=False ) if use_usp: + import torch.distributed as dist dist.barrier(device_ids=[dist.get_rank()]) self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) if isinstance(self.path, list) and len(self.path) == 1: