update import

This commit is contained in:
Artiprocher
2025-06-19 10:04:24 +08:00
parent 551721658b
commit 31369bab15

View File

@@ -170,8 +170,9 @@ class ModelConfig:
if self.path is None:
if self.model_id is None or self.origin_file_pattern is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
import torch.distributed as dist
skip_download = True if use_usp and dist.get_rank() != 0 else skip_download
if use_usp:
import torch.distributed as dist
skip_download = dist.get_rank() != 0
if not skip_download:
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
snapshot_download(
@@ -182,6 +183,7 @@ class ModelConfig:
local_files_only=False
)
if use_usp:
import torch.distributed as dist
dist.barrier(device_ids=[dist.get_rank()])
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1: