mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
update model loader
This commit is contained in:
@@ -353,7 +353,8 @@ preset_models_on_modelscope = {
|
|||||||
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
||||||
],
|
],
|
||||||
# Qwen Prompt
|
# Qwen Prompt
|
||||||
"QwenPrompt": [
|
"QwenPrompt": {
|
||||||
|
"file_list": [
|
||||||
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
@@ -363,8 +364,13 @@ preset_models_on_modelscope = {
|
|||||||
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/QwenPrompt/qwen2-1.5b-instruct",
|
||||||
|
],
|
||||||
|
},
|
||||||
# Beautiful Prompt
|
# Beautiful Prompt
|
||||||
"BeautifulPrompt": [
|
"BeautifulPrompt": {
|
||||||
|
"file_list": [
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
@@ -372,8 +378,13 @@ preset_models_on_modelscope = {
|
|||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
||||||
|
],
|
||||||
|
},
|
||||||
# Omost prompt
|
# Omost prompt
|
||||||
"OmostPrompt":[
|
"OmostPrompt": {
|
||||||
|
"file_list": [
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
@@ -383,8 +394,13 @@ preset_models_on_modelscope = {
|
|||||||
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
||||||
|
],
|
||||||
|
},
|
||||||
# Translator
|
# Translator
|
||||||
"opus-mt-zh-en": [
|
"opus-mt-zh-en": {
|
||||||
|
"file_list": [
|
||||||
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
||||||
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
||||||
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
||||||
@@ -394,6 +410,10 @@ preset_models_on_modelscope = {
|
|||||||
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
||||||
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/translator/opus-mt-zh-en",
|
||||||
|
],
|
||||||
|
},
|
||||||
# IP-Adapter
|
# IP-Adapter
|
||||||
"IP-Adapter-SD": [
|
"IP-Adapter-SD": [
|
||||||
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
||||||
@@ -404,7 +424,8 @@ preset_models_on_modelscope = {
|
|||||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||||
],
|
],
|
||||||
# Kolors
|
# Kolors
|
||||||
"Kolors": [
|
"Kolors": {
|
||||||
|
"file_list": [
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||||
@@ -417,11 +438,18 @@ preset_models_on_modelscope = {
|
|||||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/kolors/Kolors/text_encoder",
|
||||||
|
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||||
|
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
||||||
|
],
|
||||||
|
},
|
||||||
"SDXL-vae-fp16-fix": [
|
"SDXL-vae-fp16-fix": [
|
||||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
||||||
],
|
],
|
||||||
# FLUX
|
# FLUX
|
||||||
"FLUX.1-dev": [
|
"FLUX.1-dev": {
|
||||||
|
"file_list": [
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
@@ -430,6 +458,30 @@ preset_models_on_modelscope = {
|
|||||||
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||||
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||||
|
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"FLUX.1-schnell": {
|
||||||
|
"file_list": [
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
||||||
|
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||||
|
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
||||||
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||||
|
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||||
|
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
||||||
|
],
|
||||||
|
},
|
||||||
# ESRGAN
|
# ESRGAN
|
||||||
"ESRGAN_x4": [
|
"ESRGAN_x4": [
|
||||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||||
@@ -439,7 +491,8 @@ preset_models_on_modelscope = {
|
|||||||
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
||||||
],
|
],
|
||||||
# CogVideo
|
# CogVideo
|
||||||
"CogVideoX-5B": [
|
"CogVideoX-5B": {
|
||||||
|
"file_list": [
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||||
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
||||||
@@ -450,6 +503,12 @@ preset_models_on_modelscope = {
|
|||||||
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
||||||
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
||||||
],
|
],
|
||||||
|
"load_path": [
|
||||||
|
"models/CogVideo/CogVideoX-5b/text_encoder",
|
||||||
|
"models/CogVideo/CogVideoX-5b/transformer",
|
||||||
|
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
||||||
|
],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
Preset_model_id: TypeAlias = Literal[
|
Preset_model_id: TypeAlias = Literal[
|
||||||
"HunyuanDiT",
|
"HunyuanDiT",
|
||||||
@@ -481,6 +540,7 @@ Preset_model_id: TypeAlias = Literal[
|
|||||||
"SDXL-vae-fp16-fix",
|
"SDXL-vae-fp16-fix",
|
||||||
"ControlNet_union_sdxl_promax",
|
"ControlNet_union_sdxl_promax",
|
||||||
"FLUX.1-dev",
|
"FLUX.1-dev",
|
||||||
|
"FLUX.1-schnell",
|
||||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||||
"QwenPrompt",
|
"QwenPrompt",
|
||||||
"OmostPrompt",
|
"OmostPrompt",
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o
|
|||||||
|
|
||||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
file_name = os.path.basename(origin_file_path)
|
||||||
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
|
if file_name in os.listdir(local_dir):
|
||||||
return
|
print(f" {file_name} has been already in {local_dir}.")
|
||||||
else:
|
else:
|
||||||
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||||
@@ -25,11 +25,10 @@ def download_from_huggingface(model_id, origin_file_path, local_dir):
|
|||||||
os.makedirs(local_dir, exist_ok=True)
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
file_name = os.path.basename(origin_file_path)
|
file_name = os.path.basename(origin_file_path)
|
||||||
if file_name in os.listdir(local_dir):
|
if file_name in os.listdir(local_dir):
|
||||||
return f"{file_name} has already been downloaded to {local_dir}."
|
print(f" {file_name} has been already in {local_dir}.")
|
||||||
else:
|
else:
|
||||||
print(f"Start downloading {os.path.join(local_dir, file_name)}")
|
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||||
|
|
||||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||||
target_file_path = os.path.join(local_dir, file_name)
|
target_file_path = os.path.join(local_dir, file_name)
|
||||||
if downloaded_file_path != target_file_path:
|
if downloaded_file_path != target_file_path:
|
||||||
@@ -51,16 +50,14 @@ website_to_download_fn = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def download_models(
|
def download_customized_models(
|
||||||
model_id_list: List[Preset_model_id] = [],
|
model_id,
|
||||||
|
origin_file_path,
|
||||||
|
local_dir,
|
||||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||||
):
|
):
|
||||||
print(f"Downloading models: {model_id_list}")
|
|
||||||
downloaded_files = []
|
downloaded_files = []
|
||||||
for model_id in model_id_list:
|
|
||||||
for website in downloading_priority:
|
for website in downloading_priority:
|
||||||
if model_id in website_to_preset_models[website]:
|
|
||||||
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
|
|
||||||
# Check if the file is downloaded.
|
# Check if the file is downloaded.
|
||||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||||
if file_to_download in downloaded_files:
|
if file_to_download in downloaded_files:
|
||||||
@@ -70,3 +67,45 @@ def download_models(
|
|||||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||||
downloaded_files.append(file_to_download)
|
downloaded_files.append(file_to_download)
|
||||||
return downloaded_files
|
return downloaded_files
|
||||||
|
|
||||||
|
|
||||||
|
def download_models(
|
||||||
|
model_id_list: List[Preset_model_id] = [],
|
||||||
|
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||||
|
):
|
||||||
|
print(f"Downloading models: {model_id_list}")
|
||||||
|
downloaded_files = []
|
||||||
|
load_files = []
|
||||||
|
|
||||||
|
for model_id in model_id_list:
|
||||||
|
for website in downloading_priority:
|
||||||
|
if model_id in website_to_preset_models[website]:
|
||||||
|
|
||||||
|
# Parse model metadata
|
||||||
|
model_metadata = website_to_preset_models[website][model_id]
|
||||||
|
if isinstance(model_metadata, list):
|
||||||
|
file_data = model_metadata
|
||||||
|
else:
|
||||||
|
file_data = model_metadata.get("file_list", [])
|
||||||
|
|
||||||
|
# Try downloading the model from this website.
|
||||||
|
model_files = []
|
||||||
|
for model_id, origin_file_path, local_dir in file_data:
|
||||||
|
# Check if the file is downloaded.
|
||||||
|
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||||
|
if file_to_download in downloaded_files:
|
||||||
|
continue
|
||||||
|
# Download
|
||||||
|
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||||
|
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||||
|
downloaded_files.append(file_to_download)
|
||||||
|
model_files.append(file_to_download)
|
||||||
|
|
||||||
|
# If the model is successfully downloaded, break.
|
||||||
|
if len(model_files) > 0:
|
||||||
|
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||||
|
model_files = model_metadata["load_path"]
|
||||||
|
load_files.extend(model_files)
|
||||||
|
break
|
||||||
|
|
||||||
|
return load_files
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from torch import Tensor
|
|||||||
from typing_extensions import Literal, TypeAlias
|
from typing_extensions import Literal, TypeAlias
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .downloader import download_models, Preset_model_id, Preset_model_website
|
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||||
|
|
||||||
from .sd_text_encoder import SDTextEncoder
|
from .sd_text_encoder import SDTextEncoder
|
||||||
from .sd_unet import SDUNet
|
from .sd_unet import SDUNet
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
|
||||||
|
|
||||||
|
|
||||||
def prepare_extra_input(self, latents=None, guidance=0.0):
|
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||||
latent_image_ids = self.dit.prepare_image_ids(latents)
|
latent_image_ids = self.dit.prepare_image_ids(latents)
|
||||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||||
return {"image_ids": latent_image_ids, "guidance": guidance}
|
return {"image_ids": latent_image_ids, "guidance": guidance}
|
||||||
@@ -80,13 +80,13 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
mask_scales= None,
|
mask_scales= None,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=1.0,
|
cfg_scale=1.0,
|
||||||
embedded_guidance=0.0,
|
embedded_guidance=1.0,
|
||||||
input_image=None,
|
input_image=None,
|
||||||
denoising_strength=1.0,
|
denoising_strength=1.0,
|
||||||
height=1024,
|
height=1024,
|
||||||
width=1024,
|
width=1024,
|
||||||
num_inference_steps=30,
|
num_inference_steps=30,
|
||||||
t5_sequence_length=256,
|
t5_sequence_length=512,
|
||||||
tiled=False,
|
tiled=False,
|
||||||
tile_size=128,
|
tile_size=128,
|
||||||
tile_stride=64,
|
tile_stride=64,
|
||||||
|
|||||||
Reference in New Issue
Block a user