diff --git a/README.md b/README.md index 9e7e54e..ccda9cf 100644 --- a/README.md +++ b/README.md @@ -139,10 +139,11 @@ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006 #### Long Video Synthesis -We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) +We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc +https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e #### Toon Shading @@ -166,7 +167,7 @@ LoRA fine-tuning is supported in [`examples/train`](./examples/train/). |FLUX|Stable Diffusion 3| |-|-| -|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| +|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| |Kolors|Hunyuan-DiT| |-|-| diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 55c9270..27223e9 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -353,47 +353,67 @@ preset_models_on_modelscope = { ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ], # Qwen Prompt - "QwenPrompt": [ - ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), - ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), - ], + "QwenPrompt": { + "file_list": [ + ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), + ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), + ], + "load_path": [ + "models/QwenPrompt/qwen2-1.5b-instruct", + ], + }, # Beautiful Prompt - "BeautifulPrompt": [ - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), - ], + "BeautifulPrompt": { + "file_list": [ + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), + ], + "load_path": [ + "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd", + ], + }, # Omost prompt - "OmostPrompt":[ - ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), - ], + "OmostPrompt": { + "file_list": [ + ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ], + "load_path": [ + "models/OmostPrompt/omost-llama-3-8b-4bits", + ], + }, # Translator - "opus-mt-zh-en": [ - ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), - ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), - ], + "opus-mt-zh-en": { + "file_list": [ + ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), + ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), + ], + "load_path": [ + "models/translator/opus-mt-zh-en", + ], + }, # IP-Adapter "IP-Adapter-SD": [ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"), @@ -404,32 +424,64 @@ preset_models_on_modelscope = { ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ], # Kolors - "Kolors": [ - ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), - ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), - ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), - ], + "Kolors": { + "file_list": [ + ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), + ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), + ], + "load_path": [ + "models/kolors/Kolors/text_encoder", + "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors", + "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors", + ], + }, "SDXL-vae-fp16-fix": [ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ], # FLUX - "FLUX.1-dev": [ - ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), - ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), - ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), - ], + "FLUX.1-dev": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" + ], + }, + "FLUX.1-schnell": { + "file_list": [ + ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), + ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), + ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"), + ], + "load_path": [ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors" + ], + }, # ESRGAN "ESRGAN_x4": [ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), @@ -439,17 +491,24 @@ preset_models_on_modelscope = { ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), ], # CogVideo - "CogVideoX-5B": [ - ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), - ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), - ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), - ], + "CogVideoX-5B": { + "file_list": [ + ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), + ], + "load_path": [ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", + ], + }, } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -481,6 +540,7 @@ Preset_model_id: TypeAlias = Literal[ "SDXL-vae-fp16-fix", "ControlNet_union_sdxl_promax", "FLUX.1-dev", + "FLUX.1-schnell", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", "OmostPrompt", diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py index 00b90d1..94aff4c 100644 --- a/diffsynth/extensions/ESRGAN/__init__.py +++ b/diffsynth/extensions/ESRGAN/__init__.py @@ -107,6 +107,12 @@ class ESRGAN(torch.nn.Module): @torch.no_grad() def upscale(self, images, batch_size=4, progress_bar=lambda x:x): + if not isinstance(images, list): + images = [images] + is_single_image = True + else: + is_single_image = False + # Preprocess input_tensor = self.process_images(images) @@ -126,4 +132,6 @@ class ESRGAN(torch.nn.Module): # To images output_images = self.decode_images(output_tensor) + if is_single_image: + output_images = output_images[0] return output_images diff --git a/diffsynth/models/downloader.py b/diffsynth/models/downloader.py index 6801d71..6c726f6 100644 --- a/diffsynth/models/downloader.py +++ b/diffsynth/models/downloader.py @@ -8,28 +8,27 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o def download_from_modelscope(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) - if os.path.basename(origin_file_path) in os.listdir(local_dir): - print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") - return + file_name = os.path.basename(origin_file_path) + if file_name in os.listdir(local_dir): + print(f" {file_name} has been already in {local_dir}.") else: - print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") - snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) - target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) - if downloaded_file_path != target_file_path: - shutil.move(downloaded_file_path, target_file_path) - shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) + print(f" Start downloading {os.path.join(local_dir, file_name)}") + snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) + downloaded_file_path = os.path.join(local_dir, origin_file_path) + target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) + if downloaded_file_path != target_file_path: + shutil.move(downloaded_file_path, target_file_path) + shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) def download_from_huggingface(model_id, origin_file_path, local_dir): os.makedirs(local_dir, exist_ok=True) file_name = os.path.basename(origin_file_path) if file_name in os.listdir(local_dir): - return f"{file_name} has already been downloaded to {local_dir}." + print(f" {file_name} has been already in {local_dir}.") else: - print(f"Start downloading {os.path.join(local_dir, file_name)}") + print(f" Start downloading {os.path.join(local_dir, file_name)}") hf_hub_download(model_id, origin_file_path, local_dir=local_dir) - downloaded_file_path = os.path.join(local_dir, origin_file_path) target_file_path = os.path.join(local_dir, file_name) if downloaded_file_path != target_file_path: @@ -51,16 +50,47 @@ website_to_download_fn = { } +def download_customized_models( + model_id, + origin_file_path, + local_dir, + downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], +): + downloaded_files = [] + for website in downloading_priority: + # Check if the file is downloaded. + file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) + if file_to_download in downloaded_files: + continue + # Download + website_to_download_fn[website](model_id, origin_file_path, local_dir) + if os.path.basename(origin_file_path) in os.listdir(local_dir): + downloaded_files.append(file_to_download) + return downloaded_files + + def download_models( model_id_list: List[Preset_model_id] = [], downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], ): print(f"Downloading models: {model_id_list}") downloaded_files = [] + load_files = [] + for model_id in model_id_list: for website in downloading_priority: if model_id in website_to_preset_models[website]: - for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]: + + # Parse model metadata + model_metadata = website_to_preset_models[website][model_id] + if isinstance(model_metadata, list): + file_data = model_metadata + else: + file_data = model_metadata.get("file_list", []) + + # Try downloading the model from this website. + model_files = [] + for model_id, origin_file_path, local_dir in file_data: # Check if the file is downloaded. file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) if file_to_download in downloaded_files: @@ -69,4 +99,13 @@ def download_models( website_to_download_fn[website](model_id, origin_file_path, local_dir) if os.path.basename(origin_file_path) in os.listdir(local_dir): downloaded_files.append(file_to_download) - return downloaded_files + model_files.append(file_to_download) + + # If the model is successfully downloaded, break. + if len(model_files) > 0: + if isinstance(model_metadata, dict) and "load_path" in model_metadata: + model_files = model_metadata["load_path"] + load_files.extend(model_files) + break + + return load_files diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 648d23f..ffe55b8 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module): conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: + guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) prompt_emb = self.context_embedder(prompt_emb) image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7f5eef8..0e8a51e 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -4,7 +4,7 @@ from torch import Tensor from typing_extensions import Literal, TypeAlias from typing import List -from .downloader import download_models, Preset_model_id, Preset_model_website +from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website from .sd_text_encoder import SDTextEncoder from .sd_unet import SDUNet diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 956e9ba..55cfc14 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module): mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) return prompt, local_prompts, masks, mask_scales + def enable_cpu_offload(self): self.cpu_offload = True + def load_models_to_device(self, loadmodel_names=[]): # only load models to device if cpu_offload is enabled if not self.cpu_offload: @@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module): model.to(self.device) # fresh the cuda cache torch.cuda.empty_cache() + + + def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16): + generator = None if seed is None else torch.Generator(device).manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) + return noise diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 5cd57f1..06f5649 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -58,14 +58,14 @@ class FluxImagePipeline(BasePipeline): return image - def encode_prompt(self, prompt, positive=True, t5_sequence_length=256): + def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length ) return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} - def prepare_extra_input(self, latents=None, guidance=0.0): + def prepare_extra_input(self, latents=None, guidance=1.0): latent_image_ids = self.dit.prepare_image_ids(latents) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) return {"image_ids": latent_image_ids, "guidance": guidance} @@ -80,16 +80,17 @@ class FluxImagePipeline(BasePipeline): mask_scales= None, negative_prompt="", cfg_scale=1.0, - embedded_guidance=0.0, + embedded_guidance=3.5, input_image=None, denoising_strength=1.0, height=1024, width=1024, num_inference_steps=30, - t5_sequence_length=256, + t5_sequence_length=512, tiled=False, tile_size=128, tile_stride=64, + seed=None, progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline): self.load_models_to_device(['vae_encoder']) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) latents = self.encode_image(image, **tiler_kwargs) - noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: - latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) # Extend prompt self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) diff --git a/diffsynth/prompters/flux_prompter.py b/diffsynth/prompters/flux_prompter.py index 07d6dc7..9a6bd7d 100644 --- a/diffsynth/prompters/flux_prompter.py +++ b/diffsynth/prompters/flux_prompter.py @@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter): prompt, positive=True, device="cuda", - t5_sequence_length=256, + t5_sequence_length=512, ): prompt = self.process_prompt(prompt, positive=positive) diff --git a/examples/ExVideo/ExVideo_cogvideox_test.py b/examples/ExVideo/ExVideo_cogvideox_test.py new file mode 100644 index 0000000..4d0fd3a --- /dev/null +++ b/examples/ExVideo/ExVideo_cogvideox_test.py @@ -0,0 +1,21 @@ +from diffsynth import ModelManager, CogVideoPipeline, save_video, download_models +import torch + + +download_models(["CogVideoX-5B", "ExVideo-CogVideoX-LoRA-129f-v1"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16) +model_manager.load_models([ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", +]) +model_manager.load_lora("models/lora/ExVideo-CogVideoX-LoRA-129f-v1.safetensors") +pipe = CogVideoPipeline.from_model_manager(model_manager) + +torch.manual_seed(6) +video = pipe( + prompt="an astronaut riding a horse on Mars.", + height=480, width=720, num_frames=129, + cfg_scale=7.0, num_inference_steps=100, +) +save_video(video, "video_with_lora.mp4", fps=8, quality=5) diff --git a/examples/ExVideo/README.md b/examples/ExVideo/README.md index cb57d48..32f8d80 100644 --- a/examples/ExVideo/README.md +++ b/examples/ExVideo/README.md @@ -4,11 +4,19 @@ ExVideo is a post-tuning technique aimed at enhancing the capability of video ge * [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) * [Technical report](https://arxiv.org/abs/2406.14130) -* [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) -* Extended models +* **[New]** Extended models (ExVideo-CogVideoX) + * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) + * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) +* Extended models (ExVideo-SVD) * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1) +## Example: Text-to-video via extended CogVideoX-5B + +Generate a video using CogVideoX-5B and our extension module. See [ExVideo_cogvideox_test.py](./ExVideo_cogvideox_test.py). + +https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e + ## Example: Text-to-video via extended Stable Video Diffusion Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py). diff --git a/examples/image_synthesis/README.md b/examples/image_synthesis/README.md index d2a52c5..65dbe66 100644 --- a/examples/image_synthesis/README.md +++ b/examples/image_synthesis/README.md @@ -10,7 +10,7 @@ The original version of FLUX doesn't support classifier-free guidance; however, |1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)| |-|-|-| -|![image_1024](https://github.com/user-attachments/assets/ce01327f-068f-45f5-aba9-0fa45eb26199)|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_2048_highres](https://github.com/user-attachments/assets/a4bb776f-d9f0-4450-968c-c5d090a3ab4c)| +|![image_1024](https://github.com/user-attachments/assets/9cbd1f6f-4ac4-4f8b-bf46-218d812a15a0)|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_2048_highres](https://github.com/user-attachments/assets/2e92b2f8-c177-454f-84f6-f6f5d3aaeeff)| ### Example: Stable Diffusion diff --git a/examples/image_synthesis/flux_text_to_image.py b/examples/image_synthesis/flux_text_to_image.py index a2e5199..6a50df3 100644 --- a/examples/image_synthesis/flux_text_to_image.py +++ b/examples/image_synthesis/flux_text_to_image.py @@ -12,30 +12,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) image.save("image_2048_highres.jpg") diff --git a/examples/image_synthesis/flux_text_to_image_low_vram.py b/examples/image_synthesis/flux_text_to_image_low_vram.py index b98929c..985f009 100644 --- a/examples/image_synthesis/flux_text_to_image_low_vram.py +++ b/examples/image_synthesis/flux_text_to_image_low_vram.py @@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda") pipe.enable_cpu_offload() pipe.dit.quantize() -prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." -negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," # Disable classifier-free guidance (consistent with the original implementation of FLUX.1) -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5 + num_inference_steps=50, embedded_guidance=3.5 ) image.save("image_1024.jpg") # Enable classifier-free guidance -torch.manual_seed(6) +torch.manual_seed(9) image = pipe( prompt=prompt, negative_prompt=negative_prompt, - num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 + num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5 ) image.save("image_1024_cfg.jpg") # Highres-fix -torch.manual_seed(7) +torch.manual_seed(10) image = pipe( prompt=prompt, - num_inference_steps=30, embedded_guidance=3.5, + num_inference_steps=50, embedded_guidance=3.5, input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True ) -image.save("image_2048_highres.jpg") \ No newline at end of file +image.save("image_2048_highres.jpg")