This commit is contained in:
Artiprocher
2025-11-19 20:22:21 +08:00
parent 6ad8d73717
commit eeb55a0ce6
88 changed files with 3113 additions and 78 deletions

View File

@@ -2,6 +2,7 @@ import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@@ -10,7 +11,7 @@ class ModelConfig:
path: Union[str, list[str]] = None
model_id: str = None
origin_file_pattern: Union[str, list[str]] = None
download_resource: str = "ModelScope"
download_resource: str = None
local_model_path: str = None
skip_download: bool = None
offload_device: Optional[Union[str, torch.device]] = None
@@ -30,13 +31,29 @@ class ModelConfig:
def download(self):
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
if self.download_resource is None:
if os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') is not None:
self.download_resource = os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE')
else:
self.download_resource = "modelscope"
if self.download_resource.lower() == "modelscope":
snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_file_pattern=self.origin_file_pattern,
ignore_file_pattern=downloaded_files,
local_files_only=False
)
elif self.download_resource.lower() == "huggingface":
hf_snapshot_download(
self.model_id,
local_dir=os.path.join(self.local_model_path, self.model_id),
allow_patterns=self.origin_file_pattern,
ignore_patterns=downloaded_files,
local_files_only=False
)
else:
raise ValueError("`download_resource` should be `modelscope` or `huggingface`.")
def require_downloading(self):
if self.path is not None: