From ec352cfce21bb967f54fb0cf38a6863cf3f96cc6 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 8 Oct 2024 16:46:44 +0800 Subject: [PATCH 1/6] update model loader --- diffsynth/configs/model_config.py | 202 +++++++++++++++++++----------- diffsynth/models/downloader.py | 69 +++++++--- diffsynth/models/model_manager.py | 2 +- diffsynth/pipelines/flux_image.py | 6 +- 4 files changed, 189 insertions(+), 90 deletions(-) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 55c9270..27223e9 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -353,47 +353,67 @@ preset_models_on_modelscope = { ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ], # Qwen Prompt - "QwenPrompt": [ - ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ], + "QwenPrompt": { + "file_list": [ + ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ], + "load_path": [ + "models/QwenPrompt/qwen2-1.5b-instruct", + ], + }, # Beautiful Prompt - "BeautifulPrompt": [ - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ], + "BeautifulPrompt": { + "file_list": [ + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ], + "load_path": [ + "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd", + ], + }, # Omost prompt - "OmostPrompt":[ - ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ], + "OmostPrompt": { + "file_list": [ + ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ], + "load_path": [ + "models/OmostPrompt/omost-llama-3-8b-4bits", + ], + }, # Translator - "opus-mt-zh-en": [ - ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), - ], + "opus-mt-zh-en": { + "file_list": [ + ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), + ], + "load_path": [ + "models/translator/opus-mt-zh-en", + ], + }, # IP-Adapter "IP-Adapter-SD": [ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"), @@ -404,32 +424,64 @@ preset_models_on_modelscope = { ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ], # Kolors - "Kolors": [ - ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), - ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), - ], + "Kolors": { + "file_list": [ + ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), + ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), + ], + "load_path": [ + "models/kolors/Kolors/text_encoder", + "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors", + "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors", + ], + }, "SDXL-vae-fp16-fix": [ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ], # FLUX - "FLUX.1-dev": [ - ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), - ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), - ], + "FLUX.1-dev": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" + ], + }, + "FLUX.1-schnell": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors" + ], + }, # ESRGAN "ESRGAN_x4": [ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), @@ -439,17 +491,24 @@ preset_models_on_modelscope = { ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), ], # CogVideo - "CogVideoX-5B": [ - ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), - ], + "CogVideoX-5B": { + "file_list": [ + ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), + ], + "load_path": [ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", + ], + }, } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -481,6 +540,7 @@ Preset_model_id: TypeAlias = Literal[ "SDXL-vae-fp16-fix", "ControlNet_union_sdxl_promax", "FLUX.1-dev", + "FLUX.1-schnell", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", "OmostPrompt", diff --git a/diffsynth/models/downloader.py b/diffsynth/models/downloader.py index 6801d71..6c726f6 100644 --- a/diffsynth/models/downloader.py +++ b/diffsynth/models/downloader.py @@ -8,28 +8,27 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o def download_from_modelscope(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) - if os.path.basename(origin_file_path) in os.listdir(local_dir): - print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") - return + file_name = os.path.basename(origin_file_path) + if file_name in os.listdir(local_dir): + print(f" {file_name} has been already in {local_dir}.") else: - print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") - snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) - target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) - if downloaded_file_path != target_file_path: - shutil.move(downloaded_file_path, target_file_path) - shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) + print(f" Start downloading {os.path.join(local_dir, file_name)}") + snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) + downloaded_file_path = os.path.join(local_dir, origin_file_path) + target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) + if downloaded_file_path != target_file_path: + shutil.move(downloaded_file_path, target_file_path) + shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) def download_from_huggingface(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) file_name = os.path.basename(origin_file_path) if file_name in os.listdir(local_dir): - return f"{file_name} has already been downloaded to {local_dir}." + print(f" {file_name} has been already in {local_dir}.") else: - print(f"Start downloading {os.path.join(local_dir, file_name)}") + print(f" Start downloading {os.path.join(local_dir, file_name)}") hf_hub_download(model_id, origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) target_file_path = os.path.join(local_dir, file_name) if downloaded_file_path != target_file_path: @@ -51,16 +50,47 @@ website_to_download_fn = { } +def download_customized_models( + model_id, + origin_file_path, + local_dir, + downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], +): + downloaded_files = [] + for website in downloading_priority: + # Check if the file is downloaded. + file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) + if file_to_download in downloaded_files: + continue + # Download + website_to_download_fn[website](model_id, origin_file_path, local_dir) + if os.path.basename(origin_file_path) in os.listdir(local_dir): + downloaded_files.append(file_to_download) + return downloaded_files + + def download_models( model_id_list: List[Preset_model_id] = [], downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], ): print(f"Downloading models: {model_id_list}") downloaded_files = [] + load_files = [] + for model_id in model_id_list: for website in downloading_priority: if model_id in website_to_preset_models[website]: - for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]: + + # Parse model metadata + model_metadata = website_to_preset_models[website][model_id] + if isinstance(model_metadata, list): + file_data = model_metadata + else: + file_data = model_metadata.get("file_list", []) + + # Try downloading the model from this website. + model_files = [] + for model_id, origin_file_path, local_dir in file_data: # Check if the file is downloaded. file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) if file_to_download in downloaded_files: @@ -69,4 +99,13 @@ def download_models( website_to_download_fn[website](model_id, origin_file_path, local_dir) if os.path.basename(origin_file_path) in os.listdir(local_dir): downloaded_files.append(file_to_download) - return downloaded_files + model_files.append(file_to_download) + + # If the model is successfully downloaded, break. + if len(model_files) > 0: + if isinstance(model_metadata, dict) and "load_path" in model_metadata: + model_files = model_metadata["load_path"] + load_files.extend(model_files) + break + + return load_files diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7f5eef8..0e8a51e 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -4,7 +4,7 @@ from torch import Tensor from typing_extensions import Literal, TypeAlias from typing import List -from .downloader import download_models, Preset_model_id, Preset_model_website +from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website from .sd_text_encoder import SDTextEncoder from .sd_unet import SDUNet diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 5cd57f1..9ccdeb9 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -65,7 +65,7 @@ class FluxImagePipeline(BasePipeline): return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} - def prepare_extra_input(self, latents=None, guidance=0.0): + def prepare_extra_input(self, latents=None, guidance=1.0): latent_image_ids = self.dit.prepare_image_ids(latents) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) return {"image_ids": latent_image_ids, "guidance": guidance} @@ -80,13 +80,13 @@ class FluxImagePipeline(BasePipeline): mask_scales= None, negative_prompt="", cfg_scale=1.0, - embedded_guidance=0.0, + embedded_guidance=1.0, input_image=None, denoising_strength=1.0, height=1024, width=1024, num_inference_steps=30, - t5_sequence_length=256, + t5_sequence_length=512, tiled=False, tile_size=128, tile_stride=64, From 41ea2f811a57fbff76e7aacc6266d1fa8e4a154f Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 8 Oct 2024 18:23:39 +0800 Subject: [PATCH 2/6] update ESRGAN --- diffsynth/extensions/ESRGAN/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py index 00b90d1..94aff4c 100644 --- a/diffsynth/extensions/ESRGAN/__init__.py +++ b/diffsynth/extensions/ESRGAN/__init__.py @@ -107,6 +107,12 @@ class ESRGAN(torch.nn.Module): @torch.no_grad() def upscale(self, images, batch_size=4, progress_bar=lambda x:x): + if not isinstance(images, list): + images = [images] + is_single_image = True + else: + is_single_image = False + # Preprocess input_tensor = self.process_images(images) @@ -126,4 +132,6 @@ class ESRGAN(torch.nn.Module): # To images output_images = self.decode_images(output_tensor) + if is_single_image: + output_images = output_images[0] return output_images From fa0fa95bb60e7eaa5f656ae8e9e4cecf75afd5c2 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 10 Oct 2024 17:05:04 +0800 Subject: [PATCH 3/6] update flux pipeline --- diffsynth/models/flux_dit.py | 1 + diffsynth/pipelines/base.py | 8 ++++++++ diffsynth/pipelines/flux_image.py | 9 +++++---- diffsynth/prompters/flux_prompter.py | 2 +- examples/image_synthesis/flux_text_to_image.py | 16 ++++++++-------- .../flux_text_to_image_low_vram.py | 18 +++++++++--------- 6 files changed, 32 insertions(+), 22 deletions(-) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 648d23f..ffe55b8 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module): conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: + guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) prompt_emb = self.context_embedder(prompt_emb) image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 956e9ba..55cfc14 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module): mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) return prompt, local_prompts, masks, mask_scales + def enable_cpu_offload(self): self.cpu_offload = True + def load_models_to_device(self, loadmodel_names=[]): # only load models to device if cpu_offload is enabled if not self.cpu_offload: @@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module): model.to(self.device) # fresh the cuda cache torch.cuda.empty_cache() + + + def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): + generator = None if seed is None else torch.Generator(device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + return noise diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 9ccdeb9..06f5649 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -58,7 +58,7 @@ class FluxImagePipeline(BasePipeline): return image - def encode_prompt(self, prompt, positive=True, t5_sequence_length=256): + def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length ) @@ -80,7 +80,7 @@ class FluxImagePipeline(BasePipeline): mask_scales= None, negative_prompt="", cfg_scale=1.0, - embedded_guidance=1.0, + embedded_guidance=3.5, input_image=None, denoising_strength=1.0, height=1024, @@ -90,6 +90,7 @@ class FluxImagePipeline(BasePipeline): tiled=False, tile_size=128, tile_stride=64, + seed=None, progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline): self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) - noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: - latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) # Extend prompt self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py index 07d6dc7..9a6bd7d 100644 --- a/diffsynth/prompters/flux_prompter.py +++ b/diffsynth/prompters/flux_prompter.py @@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter): prompt, positive=True, device="cuda", - t5_sequence_length=256, + t5_sequence_length=512, ): prompt = self.process_prompt(prompt, positive=positive) diff --git a/examples/image_synthesis/flux_text_to_image.py b/examples/image_synthesis/flux_text_to_image.py index a2e5199..6a50df3 100644 --- a/examples/image_synthesis/flux_text_to_image.py +++ b/examples/image_synthesis/flux_text_to_image.py @@ -12,30 +12,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) image.save("image_2048_highres.jpg") diff --git a/examples/image_synthesis/flux_text_to_image_low_vram.py b/examples/image_synthesis/flux_text_to_image_low_vram.py index b98929c..985f009 100644 --- a/examples/image_synthesis/flux_text_to_image_low_vram.py +++ b/examples/image_synthesis/flux_text_to_image_low_vram.py @@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda") pipe.enable_cpu_offload() pipe.dit.quantize() -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) -image.save("image_2048_highres.jpg") \ No newline at end of file +image.save("image_2048_highres.jpg") From a0d1d5bcea4391c00ef4990d814d511e0a666518 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 10 Oct 2024 17:25:55 +0800 Subject: [PATCH 4/6] update examples --- README.md | 2 +- examples/image_synthesis/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c29a12e..ab1dc82 100644 --- a/README.md +++ b/README.md @@ -164,7 +164,7 @@ LoRA fine-tuning is supported in [`examples/train`](./examples/train/). |FLUX|Stable Diffusion 3| |-|-| -|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| +|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| |Kolors|Hunyuan-DiT| |-|-| diff --git a/examples/image_synthesis/README.md b/examples/image_synthesis/README.md index d2a52c5..65dbe66 100644 --- a/examples/image_synthesis/README.md +++ b/examples/image_synthesis/README.md @@ -10,7 +10,7 @@ The original version of FLUX doesn't support classifier-free guidance; however, |1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)| |-|-|-| -|![image_1024](https://github.com/user-attachments/assets/ce01327f-068f-45f5-aba9-0fa45eb26199)|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_2048_highres](https://github.com/user-attachments/assets/a4bb776f-d9f0-4450-968c-c5d090a3ab4c)| +|![image_1024](https://github.com/user-attachments/assets/9cbd1f6f-4ac4-4f8b-bf46-218d812a15a0)|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_2048_highres](https://github.com/user-attachments/assets/2e92b2f8-c177-454f-84f6-f6f5d3aaeeff)| ### Example: Stable Diffusion From 66873d7d6468887307598b9f3b38b7417d065c03 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 10 Oct 2024 18:23:43 +0800 Subject: [PATCH 5/6] update examples --- examples/ExVideo/ExVideo_cogvideox_test.py | 21 +++++++++++++++++++++ examples/ExVideo/README.md | 12 ++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 examples/ExVideo/ExVideo_cogvideox_test.py diff --git a/examples/ExVideo/ExVideo_cogvideox_test.py b/examples/ExVideo/ExVideo_cogvideox_test.py new file mode 100644 index 0000000..4d0fd3a --- /dev/null +++ b/examples/ExVideo/ExVideo_cogvideox_test.py @@ -0,0 +1,21 @@ +from diffsynth import ModelManager, CogVideoPipeline, save_video, download_models +import torch + + +download_models(["CogVideoX-5B", "ExVideo-CogVideoX-LoRA-129f-v1"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16) +model_manager.load_models([ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", +]) +model_manager.load_lora("models/lora/ExVideo-CogVideoX-LoRA-129f-v1.safetensors") +pipe = CogVideoPipeline.from_model_manager(model_manager) + +torch.manual_seed(6) +video = pipe( + prompt="an astronaut riding a horse on Mars.", + height=480, width=720, num_frames=129, + cfg_scale=7.0, num_inference_steps=100, +) +save_video(video, "video_with_lora.mp4", fps=8, quality=5) diff --git a/examples/ExVideo/README.md b/examples/ExVideo/README.md index cb57d48..32f8d80 100644 --- a/examples/ExVideo/README.md +++ b/examples/ExVideo/README.md @@ -4,11 +4,19 @@ ExVideo is a post-tuning technique aimed at enhancing the capability of video ge * [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) * [Technical report](https://arxiv.org/abs/2406.14130) -* [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) -* Extended models +* **[New]** Extended models (ExVideo-CogVideoX) + * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) + * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) +* Extended models (ExVideo-SVD) * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1) +## Example: Text-to-video via extended CogVideoX-5B + +Generate a video using CogVideoX-5B and our extension module. See [ExVideo_cogvideox_test.py](./ExVideo_cogvideox_test.py). + +https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e + ## Example: Text-to-video via extended Stable Video Diffusion Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py). From e5c72ba1f2fa59f72c5ae1dbdd72afaab58eb484 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 10 Oct 2024 18:26:37 +0800 Subject: [PATCH 6/6] update examples --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ab1dc82..aecf269 100644 --- a/README.md +++ b/README.md @@ -137,10 +137,11 @@ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006 #### Long Video Synthesis -We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) +We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc +https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e #### Toon Shading