From 4c052e42bc2d4f456ac07250d076bc48169289a0 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 16 Jun 2025 15:43:39 +0800 Subject: [PATCH] fix usp download --- diffsynth/pipelines/wan_video_new.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 9086af2..1113980 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -166,10 +166,12 @@ class ModelConfig: offload_device: Optional[Union[str, torch.device]] = None offload_dtype: Optional[torch.dtype] = None - def download_if_necessary(self, local_model_path="./models", skip_download=False): + def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False): if self.path is None: if self.model_id is None or self.origin_file_pattern is None: raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""") + import torch.distributed as dist + skip_download = True if use_usp and dist.get_rank() != 0 else skip_download if not skip_download: downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id)) snapshot_download( @@ -179,6 +181,8 @@ class ModelConfig: ignore_file_pattern=downloaded_files, local_files_only=False ) + if use_usp: + dist.barrier(device_ids=[dist.get_rank()]) self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern)) if isinstance(self.path, list) and len(self.path) == 1: self.path = self.path[0] @@ -425,19 +429,21 @@ class WanVideoPipeline(BasePipeline): print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.") model_config.model_id = redirect_dict[model_config.origin_file_pattern] + # Initialize pipeline + pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) + if use_usp: pipe.initialize_usp() + # Download and load models model_manager = ModelManager() for model_config in model_configs: - model_config.download_if_necessary(local_model_path, skip_download=skip_download) + model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp) model_manager.load_model( model_config.path, device=model_config.offload_device or device, torch_dtype=model_config.offload_dtype or torch_dtype ) - # Initialize pipeline - pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype) - if use_usp: pipe.initialize_usp() + # Load models pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder") pipe.dit = model_manager.fetch_model("wan_video_dit") pipe.vae = model_manager.fetch_model("wan_video_vae")