mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update import
This commit is contained in:
@@ -170,8 +170,9 @@ class ModelConfig:
|
||||
if self.path is None:
|
||||
if self.model_id is None or self.origin_file_pattern is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
|
||||
import torch.distributed as dist
|
||||
skip_download = True if use_usp and dist.get_rank() != 0 else skip_download
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
skip_download = dist.get_rank() != 0
|
||||
if not skip_download:
|
||||
downloaded_files = glob.glob(self.origin_file_pattern, root_dir=os.path.join(local_model_path, self.model_id))
|
||||
snapshot_download(
|
||||
@@ -182,6 +183,7 @@ class ModelConfig:
|
||||
local_files_only=False
|
||||
)
|
||||
if use_usp:
|
||||
import torch.distributed as dist
|
||||
dist.barrier(device_ids=[dist.get_rank()])
|
||||
self.path = glob.glob(os.path.join(local_model_path, self.model_id, self.origin_file_pattern))
|
||||
if isinstance(self.path, list) and len(self.path) == 1:
|
||||
|
||||
Reference in New Issue
Block a user