fix usp download

This commit is contained in:
mi804
2025-06-16 15:43:39 +08:00
parent c164519ef1
commit 4c052e42bc

View File

@@ -166,10 +166,12 @@ class ModelConfig:
offload_device: Optional[Union[str, torch.device]] = None
offload_dtype: Optional[torch.dtype] = None
def download_if_necessary(self, local_model_path="./models", skip_download=False):
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
import torch.distributed as dist
skip_download = True if use_usp and dist.get_rank() != 0 else skip_download
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
@@ -179,6 +181,8 @@ class ModelConfig:
ignore_file_pattern=downloaded_files,
local_files_only=False
)
if use_usp:
dist.barrier(device_ids=[dist.get_rank()])
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]
@@ -425,19 +429,21 @@ class WanVideoPipeline(BasePipeline):
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to ({redirect_dict[model_config.origin_file_pattern]}, {model_config.origin_file_pattern}). You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern]
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp: pipe.initialize_usp()
# Download and load models
model_manager = ModelManager()
for model_config in model_configs:
model_config.download_if_necessary(local_model_path, skip_download=skip_download)
model_config.download_if_necessary(local_model_path, skip_download=skip_download, use_usp=use_usp)
model_manager.load_model(
model_config.path,
device=model_config.offload_device or device,
torch_dtype=model_config.offload_dtype or torch_dtype
)
# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
if use_usp: pipe.initialize_usp()
# Load models
pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
pipe.dit = model_manager.fetch_model("wan_video_dit")
pipe.vae = model_manager.fetch_model("wan_video_vae")