diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 55c9270..27223e9 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -353,47 +353,67 @@ preset_models_on_modelscope = { ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ], # Qwen Prompt - "QwenPrompt": [ - ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ], + "QwenPrompt": { + "file_list": [ + ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ], + "load_path": [ + "models/QwenPrompt/qwen2-1.5b-instruct", + ], + }, # Beautiful Prompt - "BeautifulPrompt": [ - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ], + "BeautifulPrompt": { + "file_list": [ + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ], + "load_path": [ + "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd", + ], + }, # Omost prompt - "OmostPrompt":[ - ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ], + "OmostPrompt": { + "file_list": [ + ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ], + "load_path": [ + "models/OmostPrompt/omost-llama-3-8b-4bits", + ], + }, # Translator - "opus-mt-zh-en": [ - ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), - ], + "opus-mt-zh-en": { + "file_list": [ + ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), + ], + "load_path": [ + "models/translator/opus-mt-zh-en", + ], + }, # IP-Adapter "IP-Adapter-SD": [ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"), @@ -404,32 +424,64 @@ preset_models_on_modelscope = { ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ], # Kolors - "Kolors": [ - ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), - ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), - ], + "Kolors": { + "file_list": [ + ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), + ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), + ], + "load_path": [ + "models/kolors/Kolors/text_encoder", + "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors", + "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors", + ], + }, "SDXL-vae-fp16-fix": [ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ], # FLUX - "FLUX.1-dev": [ - ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), - ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), - ], + "FLUX.1-dev": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" + ], + }, + "FLUX.1-schnell": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors" + ], + }, # ESRGAN "ESRGAN_x4": [ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), @@ -439,17 +491,24 @@ preset_models_on_modelscope = { ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), ], # CogVideo - "CogVideoX-5B": [ - ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), - ], + "CogVideoX-5B": { + "file_list": [ + ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), + ], + "load_path": [ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", + ], + }, } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -481,6 +540,7 @@ Preset_model_id: TypeAlias = Literal[ "SDXL-vae-fp16-fix", "ControlNet_union_sdxl_promax", "FLUX.1-dev", + "FLUX.1-schnell", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", "OmostPrompt", diff --git a/diffsynth/models/downloader.py b/diffsynth/models/downloader.py index 6801d71..6c726f6 100644 --- a/diffsynth/models/downloader.py +++ b/diffsynth/models/downloader.py @@ -8,28 +8,27 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o def download_from_modelscope(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) - if os.path.basename(origin_file_path) in os.listdir(local_dir): - print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") - return + file_name = os.path.basename(origin_file_path) + if file_name in os.listdir(local_dir): + print(f" {file_name} has been already in {local_dir}.") else: - print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") - snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) - target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) - if downloaded_file_path != target_file_path: - shutil.move(downloaded_file_path, target_file_path) - shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) + print(f" Start downloading {os.path.join(local_dir, file_name)}") + snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) + downloaded_file_path = os.path.join(local_dir, origin_file_path) + target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) + if downloaded_file_path != target_file_path: + shutil.move(downloaded_file_path, target_file_path) + shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) def download_from_huggingface(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) file_name = os.path.basename(origin_file_path) if file_name in os.listdir(local_dir): - return f"{file_name} has already been downloaded to {local_dir}." + print(f" {file_name} has been already in {local_dir}.") else: - print(f"Start downloading {os.path.join(local_dir, file_name)}") + print(f" Start downloading {os.path.join(local_dir, file_name)}") hf_hub_download(model_id, origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) target_file_path = os.path.join(local_dir, file_name) if downloaded_file_path != target_file_path: @@ -51,16 +50,47 @@ website_to_download_fn = { } +def download_customized_models( + model_id, + origin_file_path, + local_dir, + downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], +): + downloaded_files = [] + for website in downloading_priority: + # Check if the file is downloaded. + file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) + if file_to_download in downloaded_files: + continue + # Download + website_to_download_fn[website](model_id, origin_file_path, local_dir) + if os.path.basename(origin_file_path) in os.listdir(local_dir): + downloaded_files.append(file_to_download) + return downloaded_files + + def download_models( model_id_list: List[Preset_model_id] = [], downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], ): print(f"Downloading models: {model_id_list}") downloaded_files = [] + load_files = [] + for model_id in model_id_list: for website in downloading_priority: if model_id in website_to_preset_models[website]: - for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]: + + # Parse model metadata + model_metadata = website_to_preset_models[website][model_id] + if isinstance(model_metadata, list): + file_data = model_metadata + else: + file_data = model_metadata.get("file_list", []) + + # Try downloading the model from this website. + model_files = [] + for model_id, origin_file_path, local_dir in file_data: # Check if the file is downloaded. file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) if file_to_download in downloaded_files: @@ -69,4 +99,13 @@ def download_models( website_to_download_fn[website](model_id, origin_file_path, local_dir) if os.path.basename(origin_file_path) in os.listdir(local_dir): downloaded_files.append(file_to_download) - return downloaded_files + model_files.append(file_to_download) + + # If the model is successfully downloaded, break. + if len(model_files) > 0: + if isinstance(model_metadata, dict) and "load_path" in model_metadata: + model_files = model_metadata["load_path"] + load_files.extend(model_files) + break + + return load_files diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7f5eef8..0e8a51e 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -4,7 +4,7 @@ from torch import Tensor from typing_extensions import Literal, TypeAlias from typing import List -from .downloader import download_models, Preset_model_id, Preset_model_website +from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website from .sd_text_encoder import SDTextEncoder from .sd_unet import SDUNet diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 5cd57f1..9ccdeb9 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -65,7 +65,7 @@ class FluxImagePipeline(BasePipeline): return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} - def prepare_extra_input(self, latents=None, guidance=0.0): + def prepare_extra_input(self, latents=None, guidance=1.0): latent_image_ids = self.dit.prepare_image_ids(latents) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) return {"image_ids": latent_image_ids, "guidance": guidance} @@ -80,13 +80,13 @@ class FluxImagePipeline(BasePipeline): mask_scales= None, negative_prompt="", cfg_scale=1.0, - embedded_guidance=0.0, + embedded_guidance=1.0, input_image=None, denoising_strength=1.0, height=1024, width=1024, num_inference_steps=30, - t5_sequence_length=256, + t5_sequence_length=512, tiled=False, tile_size=128, tile_stride=64,