Merge pull request #234 from modelscope/doc-patch

Patch
This commit is contained in:
Zhongjie Duan
2024-10-10 19:17:00 +08:00
committed by GitHub
14 changed files with 263 additions and 116 deletions

View File

@@ -139,10 +139,11 @@ https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
#### Long Video Synthesis #### Long Video Synthesis
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) We trained extended video synthesis models, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
#### Toon Shading #### Toon Shading
@@ -166,7 +167,7 @@ LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|FLUX|Stable Diffusion 3| |FLUX|Stable Diffusion 3|
|-|-| |-|-|
|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| |![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|Kolors|Hunyuan-DiT| |Kolors|Hunyuan-DiT|
|-|-| |-|-|

View File

@@ -353,47 +353,67 @@ preset_models_on_modelscope = {
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"), ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
], ],
# Qwen Prompt # Qwen Prompt
"QwenPrompt": [ "QwenPrompt": {
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), "file_list": [
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"), ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
], ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
],
"load_path": [
"models/QwenPrompt/qwen2-1.5b-instruct",
],
},
# Beautiful Prompt # Beautiful Prompt
"BeautifulPrompt": [ "BeautifulPrompt": {
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), "file_list": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
], ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
"load_path": [
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
],
},
# Omost prompt # Omost prompt
"OmostPrompt":[ "OmostPrompt": {
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), "file_list": [
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
], ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
],
"load_path": [
"models/OmostPrompt/omost-llama-3-8b-4bits",
],
},
# Translator # Translator
"opus-mt-zh-en": [ "opus-mt-zh-en": {
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), "file_list": [
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"), ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
], ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
"load_path": [
"models/translator/opus-mt-zh-en",
],
},
# IP-Adapter # IP-Adapter
"IP-Adapter-SD": [ "IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"), ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
@@ -404,32 +424,64 @@ preset_models_on_modelscope = {
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
], ],
# Kolors # Kolors
"Kolors": [ "Kolors": {
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), "file_list": [
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
], ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"load_path": [
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
],
},
"SDXL-vae-fp16-fix": [ "SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
], ],
# FLUX # FLUX
"FLUX.1-dev": [ "FLUX.1-dev": {
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"), "file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"), ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
], ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
],
},
"FLUX.1-schnell": {
"file_list": [
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
],
"load_path": [
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
],
},
# ESRGAN # ESRGAN
"ESRGAN_x4": [ "ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -439,17 +491,24 @@ preset_models_on_modelscope = {
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
], ],
# CogVideo # CogVideo
"CogVideoX-5B": [ "CogVideoX-5B": {
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), "file_list": [
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
], ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
"load_path": [
"models/CogVideo/CogVideoX-5b/text_encoder",
"models/CogVideo/CogVideoX-5b/transformer",
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
],
},
} }
Preset_model_id: TypeAlias = Literal[ Preset_model_id: TypeAlias = Literal[
"HunyuanDiT", "HunyuanDiT",
@@ -481,6 +540,7 @@ Preset_model_id: TypeAlias = Literal[
"SDXL-vae-fp16-fix", "SDXL-vae-fp16-fix",
"ControlNet_union_sdxl_promax", "ControlNet_union_sdxl_promax",
"FLUX.1-dev", "FLUX.1-dev",
"FLUX.1-schnell",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt", "QwenPrompt",
"OmostPrompt", "OmostPrompt",

View File

@@ -107,6 +107,12 @@ class ESRGAN(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def upscale(self, images, batch_size=4, progress_bar=lambda x:x): def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
if not isinstance(images, list):
images = [images]
is_single_image = True
else:
is_single_image = False
# Preprocess # Preprocess
input_tensor = self.process_images(images) input_tensor = self.process_images(images)
@@ -126,4 +132,6 @@ class ESRGAN(torch.nn.Module):
# To images # To images
output_images = self.decode_images(output_tensor) output_images = self.decode_images(output_tensor)
if is_single_image:
output_images = output_images[0]
return output_images return output_images

View File

@@ -8,28 +8,27 @@ from ..configs.model_config import preset_models_on_huggingface, preset_models_o
def download_from_modelscope(model_id, origin_file_path, local_dir): def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True) os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir): file_name = os.path.basename(origin_file_path)
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.") if file_name in os.listdir(local_dir):
return print(f" {file_name} has been already in {local_dir}.")
else: else:
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") print(f" Start downloading {os.path.join(local_dir, file_name)}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path) downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path: if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path) shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir): def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True) os.makedirs(local_dir, exist_ok=True)
file_name = os.path.basename(origin_file_path) file_name = os.path.basename(origin_file_path)
if file_name in os.listdir(local_dir): if file_name in os.listdir(local_dir):
return f"{file_name} has already been downloaded to {local_dir}." print(f" {file_name} has been already in {local_dir}.")
else: else:
print(f"Start downloading {os.path.join(local_dir, file_name)}") print(f" Start downloading {os.path.join(local_dir, file_name)}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir) hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path) downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, file_name) target_file_path = os.path.join(local_dir, file_name)
if downloaded_file_path != target_file_path: if downloaded_file_path != target_file_path:
@@ -51,16 +50,47 @@ website_to_download_fn = {
} }
def download_customized_models(
model_id,
origin_file_path,
local_dir,
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
downloaded_files = []
for website in downloading_priority:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files
def download_models( def download_models(
model_id_list: List[Preset_model_id] = [], model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"], downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
): ):
print(f"Downloading models: {model_id_list}") print(f"Downloading models: {model_id_list}")
downloaded_files = [] downloaded_files = []
load_files = []
for model_id in model_id_list: for model_id in model_id_list:
for website in downloading_priority: for website in downloading_priority:
if model_id in website_to_preset_models[website]: if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Parse model metadata
model_metadata = website_to_preset_models[website][model_id]
if isinstance(model_metadata, list):
file_data = model_metadata
else:
file_data = model_metadata.get("file_list", [])
# Try downloading the model from this website.
model_files = []
for model_id, origin_file_path, local_dir in file_data:
# Check if the file is downloaded. # Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path)) file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files: if file_to_download in downloaded_files:
@@ -69,4 +99,13 @@ def download_models(
website_to_download_fn[website](model_id, origin_file_path, local_dir) website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir): if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download) downloaded_files.append(file_to_download)
return downloaded_files model_files.append(file_to_download)
# If the model is successfully downloaded, break.
if len(model_files) > 0:
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
model_files = model_metadata["load_path"]
load_files.extend(model_files)
break
return load_files

View File

@@ -364,6 +364,7 @@ class FluxDiT(torch.nn.Module):
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None: if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb) prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

View File

@@ -4,7 +4,7 @@ from torch import Tensor
from typing_extensions import Literal, TypeAlias from typing_extensions import Literal, TypeAlias
from typing import List from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet from .sd_unet import SDUNet

View File

@@ -65,9 +65,11 @@ class BasePipeline(torch.nn.Module):
mask_scales += [100.0] * len(extended_prompt_dict.get("masks", [])) mask_scales += [100.0] * len(extended_prompt_dict.get("masks", []))
return prompt, local_prompts, masks, mask_scales return prompt, local_prompts, masks, mask_scales
def enable_cpu_offload(self): def enable_cpu_offload(self):
self.cpu_offload = True self.cpu_offload = True
def load_models_to_device(self, loadmodel_names=[]): def load_models_to_device(self, loadmodel_names=[]):
# only load models to device if cpu_offload is enabled # only load models to device if cpu_offload is enabled
if not self.cpu_offload: if not self.cpu_offload:
@@ -85,3 +87,9 @@ class BasePipeline(torch.nn.Module):
model.to(self.device) model.to(self.device)
# fresh the cuda cache # fresh the cuda cache
torch.cuda.empty_cache() torch.cuda.empty_cache()
def generate_noise(self, shape, seed=None, device="cpu", dtype=torch.float16):
generator = None if seed is None else torch.Generator(device).manual_seed(seed)
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
return noise

View File

@@ -58,14 +58,14 @@ class FluxImagePipeline(BasePipeline):
return image return image
def encode_prompt(self, prompt, positive=True, t5_sequence_length=256): def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt( prompt_emb, pooled_prompt_emb, text_ids = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length prompt, device=self.device, positive=positive, t5_sequence_length=t5_sequence_length
) )
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids} return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_ids": text_ids}
def prepare_extra_input(self, latents=None, guidance=0.0): def prepare_extra_input(self, latents=None, guidance=1.0):
latent_image_ids = self.dit.prepare_image_ids(latents) latent_image_ids = self.dit.prepare_image_ids(latents)
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
return {"image_ids": latent_image_ids, "guidance": guidance} return {"image_ids": latent_image_ids, "guidance": guidance}
@@ -80,16 +80,17 @@ class FluxImagePipeline(BasePipeline):
mask_scales= None, mask_scales= None,
negative_prompt="", negative_prompt="",
cfg_scale=1.0, cfg_scale=1.0,
embedded_guidance=0.0, embedded_guidance=3.5,
input_image=None, input_image=None,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
width=1024, width=1024,
num_inference_steps=30, num_inference_steps=30,
t5_sequence_length=256, t5_sequence_length=512,
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
seed=None,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
@@ -104,10 +105,10 @@ class FluxImagePipeline(BasePipeline):
self.load_models_to_device(['vae_encoder']) self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs) latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else: else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Extend prompt # Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])

View File

@@ -57,7 +57,7 @@ class FluxPrompter(BasePrompter):
prompt, prompt,
positive=True, positive=True,
device="cuda", device="cuda",
t5_sequence_length=256, t5_sequence_length=512,
): ):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)

View File

@@ -0,0 +1,21 @@
from diffsynth import ModelManager, CogVideoPipeline, save_video, download_models
import torch
download_models(["CogVideoX-5B", "ExVideo-CogVideoX-LoRA-129f-v1"])
model_manager = ModelManager(torch_dtype=torch.bfloat16)
model_manager.load_models([
"models/CogVideo/CogVideoX-5b/text_encoder",
"models/CogVideo/CogVideoX-5b/transformer",
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
])
model_manager.load_lora("models/lora/ExVideo-CogVideoX-LoRA-129f-v1.safetensors")
pipe = CogVideoPipeline.from_model_manager(model_manager)
torch.manual_seed(6)
video = pipe(
prompt="an astronaut riding a horse on Mars.",
height=480, width=720, num_frames=129,
cfg_scale=7.0, num_inference_steps=100,
)
save_video(video, "video_with_lora.mp4", fps=8, quality=5)

View File

@@ -4,11 +4,19 @@ ExVideo is a post-tuning technique aimed at enhancing the capability of video ge
* [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/) * [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
* [Technical report](https://arxiv.org/abs/2406.14130) * [Technical report](https://arxiv.org/abs/2406.14130)
* [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) * **[New]** Extended models (ExVideo-CogVideoX)
* Extended models * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1)
* [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1)
* Extended models (ExVideo-SVD)
* [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1) * [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)
## Example: Text-to-video via extended CogVideoX-5B
Generate a video using CogVideoX-5B and our extension module. See [ExVideo_cogvideox_test.py](./ExVideo_cogvideox_test.py).
https://github.com/user-attachments/assets/321ee04b-8c17-479e-8a95-8cbcf21f8d7e
## Example: Text-to-video via extended Stable Video Diffusion ## Example: Text-to-video via extended Stable Video Diffusion
Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py). Generate a video using a text-to-image model and our image-to-video model. See [ExVideo_svd_test.py](./ExVideo_svd_test.py).

View File

@@ -10,7 +10,7 @@ The original version of FLUX doesn't support classifier-free guidance; however,
|1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)| |1024*1024 (original)|1024*1024 (classifier-free guidance)|2048*2048 (highres-fix)|
|-|-|-| |-|-|-|
|![image_1024](https://github.com/user-attachments/assets/ce01327f-068f-45f5-aba9-0fa45eb26199)|![image_1024_cfg](https://github.com/user-attachments/assets/6af5b106-0673-4e58-9213-cd9157eef4c0)|![image_2048_highres](https://github.com/user-attachments/assets/a4bb776f-d9f0-4450-968c-c5d090a3ab4c)| |![image_1024](https://github.com/user-attachments/assets/9cbd1f6f-4ac4-4f8b-bf46-218d812a15a0)|![image_1024_cfg](https://github.com/user-attachments/assets/984561e9-553d-4952-9443-79ce144f379f)|![image_2048_highres](https://github.com/user-attachments/assets/2e92b2f8-c177-454f-84f6-f6f5d3aaeeff)|
### Example: Stable Diffusion ### Example: Stable Diffusion

View File

@@ -12,30 +12,30 @@ model_manager.load_models([
]) ])
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1) # Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6) torch.manual_seed(9)
image = pipe( image = pipe(
prompt=prompt, prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5 num_inference_steps=50, embedded_guidance=3.5
) )
image.save("image_1024.jpg") image.save("image_1024.jpg")
# Enable classifier-free guidance # Enable classifier-free guidance
torch.manual_seed(6) torch.manual_seed(9)
image = pipe( image = pipe(
prompt=prompt, negative_prompt=negative_prompt, prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
) )
image.save("image_1024_cfg.jpg") image.save("image_1024_cfg.jpg")
# Highres-fix # Highres-fix
torch.manual_seed(7) torch.manual_seed(10)
image = pipe( image = pipe(
prompt=prompt, prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5, num_inference_steps=50, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
) )
image.save("image_2048_highres.jpg") image.save("image_2048_highres.jpg")

View File

@@ -22,30 +22,30 @@ pipe = FluxImagePipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_cpu_offload() pipe.enable_cpu_offload()
pipe.dit.quantize() pipe.dit.quantize()
prompt = "CG. Full body. A captivating fantasy magic woman portrait in the deep sea. The woman, with blue spaghetti strap silk dress, swims in the sea. Her flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her. Smooth, delicate and fair skin." prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
negative_prompt = "dark, worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, dim, fuzzy, depth of Field, nsfw," negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,"
# Disable classifier-free guidance (consistent with the original implementation of FLUX.1) # Disable classifier-free guidance (consistent with the original implementation of FLUX.1)
torch.manual_seed(6) torch.manual_seed(9)
image = pipe( image = pipe(
prompt=prompt, prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5 num_inference_steps=50, embedded_guidance=3.5
) )
image.save("image_1024.jpg") image.save("image_1024.jpg")
# Enable classifier-free guidance # Enable classifier-free guidance
torch.manual_seed(6) torch.manual_seed(9)
image = pipe( image = pipe(
prompt=prompt, negative_prompt=negative_prompt, prompt=prompt, negative_prompt=negative_prompt,
num_inference_steps=30, cfg_scale=2.0, embedded_guidance=3.5 num_inference_steps=50, cfg_scale=2.0, embedded_guidance=3.5
) )
image.save("image_1024_cfg.jpg") image.save("image_1024_cfg.jpg")
# Highres-fix # Highres-fix
torch.manual_seed(7) torch.manual_seed(10)
image = pipe( image = pipe(
prompt=prompt, prompt=prompt,
num_inference_steps=30, embedded_guidance=3.5, num_inference_steps=50, embedded_guidance=3.5,
input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True input_image=image.resize((2048, 2048)), height=2048, width=2048, denoising_strength=0.6, tiled=True
) )
image.save("image_2048_highres.jpg") image.save("image_2048_highres.jpg")