mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update model loader
This commit is contained in:
@@ -8,28 +8,27 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
||||
return
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
return f"{file_name} has already been downloaded to {local_dir}."
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f"Start downloading {os.path.join(local_dir, file_name)}")
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, file_name)
|
||||
if downloaded_file_path != target_file_path:
|
||||
@@ -51,16 +50,47 @@ website_to_download_fn = {
|
||||
}
|
||||
|
||||
|
||||
def download_customized_models(
|
||||
model_id,
|
||||
origin_file_path,
|
||||
local_dir,
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
downloaded_files = []
|
||||
for website in downloading_priority:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
|
||||
|
||||
def download_models(
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
print(f"Downloading models: {model_id_list}")
|
||||
downloaded_files = []
|
||||
load_files = []
|
||||
|
||||
for model_id in model_id_list:
|
||||
for website in downloading_priority:
|
||||
if model_id in website_to_preset_models[website]:
|
||||
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
||||
|
||||
# Parse model metadata
|
||||
model_metadata = website_to_preset_models[website][model_id]
|
||||
if isinstance(model_metadata, list):
|
||||
file_data = model_metadata
|
||||
else:
|
||||
file_data = model_metadata.get("file_list", [])
|
||||
|
||||
# Try downloading the model from this website.
|
||||
model_files = []
|
||||
for model_id, origin_file_path, local_dir in file_data:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
@@ -69,4 +99,13 @@ def download_models(
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
model_files.append(file_to_download)
|
||||
|
||||
# If the model is successfully downloaded, break.
|
||||
if len(model_files) > 0:
|
||||
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||
model_files = model_metadata["load_path"]
|
||||
load_files.extend(model_files)
|
||||
break
|
||||
|
||||
return load_files
|
||||
|
||||
@@ -4,7 +4,7 @@ from torch import Tensor
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, Preset_model_id, Preset_model_website
|
||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
|
||||
Reference in New Issue
Block a user