refine examples

This commit is contained in:
Artiprocher
2025-06-24 19:17:43 +08:00
parent 3eb7e7530e
commit c8ad643374
8 changed files with 12 additions and 25 deletions

View File

@@ -169,7 +169,7 @@ class ModelConfig:
def download_if_necessary(self, local_model_path="./models", skip_download=False, use_usp=False):
if self.path is None:
# Check model_id and origin_file_pattern
if self.model_id is None or self.origin_file_pattern is None:
if self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`.""")
# Skip if not in rank 0
@@ -178,7 +178,11 @@ class ModelConfig:
skip_download = dist.get_rank() != 0
# Check whether the origin path is a folder
if isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
if self.origin_file_pattern is None:
self.origin_file_pattern = ""
allow_file_pattern = None
is_folder = True
elif isinstance(self.origin_file_pattern, str) and self.origin_file_pattern.endswith("/"):
allow_file_pattern = self.origin_file_pattern + "*"
is_folder = True
else: