rebuild base modules

This commit is contained in:
Artiprocher
2024-07-26 12:15:40 +08:00
parent 9471bff8a4
commit e3f8a576cf
76 changed files with 3253 additions and 3563 deletions

View File

@@ -1,6 +1,6 @@
from .data import *
from .models import *
from .prompts import *
from .prompters import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

View File

@@ -0,0 +1,243 @@
from typing_extensions import Literal, TypeAlias
from ..models.sd_text_encoder import SDTextEncoder
from ..models.sd_unet import SDUNet
from ..models.sd_vae_encoder import SDVAEEncoder
from ..models.sd_vae_decoder import SDVAEDecoder
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from ..models.sdxl_unet import SDXLUNet
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from ..models.sd3_dit import SD3DiT
from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet
from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel
from ..models.svd_image_encoder import SVDImageEncoder
from ..models.svd_unet import SVDUNet
from ..models.svd_vae_decoder import SVDVAEDecoder
from ..models.svd_vae_encoder import SVDVAEEncoder
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT
model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name)
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder"),
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator"),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
]
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
]

View File

@@ -0,0 +1,35 @@
import torch, os
from torchvision import transforms
import pandas as pd
from PIL import Image
class TextImageDataset(torch.utils.data.Dataset):
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
self.steps_per_epoch = steps_per_epoch
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.image_processor = transforms.Compose(
[
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB")
image = self.image_processor(image)
return {"text": text, "image": image}
def __len__(self):
return self.steps_per_epoch

View File

@@ -99,7 +99,8 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
return flow_list, mask_list[2], merged
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return IFNetStateDictConverter()

View File

@@ -1,814 +1 @@
import torch, os, json
from safetensors import safe_open
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_from_huggingface, download_from_modelscope
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .kolors_text_encoder import ChatGLMModel
preset_models_on_huggingface = {
"HunyuanDiT": [
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
"stable-video-diffusion-img2vid-xt": [
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
}
preset_models_on_modelscope = {
# Hunyuan DiT
"HunyuanDiT": [
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
],
# Stable Video Diffusion
"stable-video-diffusion-img2vid-xt": [
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
],
# ExVideo
"ExVideo-SVD-128f-v1": [
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
],
# Stable Diffusion
"StableDiffusion_v15": [
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
],
"DreamShaper_8": [
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
],
"AingDiffusion_v12": [
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
],
"Flat2DAnimerge_v45Sharp": [
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
],
# Textual Inversion
"TextualInversion_VeryBadImageNegative_v1.3": [
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
],
# Stable Diffusion XL
"StableDiffusionXL_v1": [
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
],
"BluePencilXL_v200": [
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
],
"StableDiffusionXL_Turbo": [
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
],
# Stable Diffusion 3
"StableDiffusion3": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
],
"StableDiffusion3_without_T5": [
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
],
# ControlNet
"ControlNet_v11f1p_sd15_depth": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
],
"ControlNet_v11p_sd15_softedge": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
],
"ControlNet_v11f1e_sd15_tile": [
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
],
"ControlNet_v11p_sd15_lineart": [
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
],
# AnimateDiff
"AnimateDiff_v2": [
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
],
"AnimateDiff_xl_beta": [
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
],
# RIFE
"RIFE": [
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
],
# Beautiful Prompt
"BeautifulPrompt": [
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
],
# Translator
"opus-mt-zh-en": [
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
],
# IP-Adapter
"IP-Adapter-SD": [
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
],
"IP-Adapter-SDXL": [
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
"SDXL-vae-fp16-fix": [
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
"stable-video-diffusion-img2vid-xt",
"ExVideo-SVD-128f-v1",
"StableDiffusion_v15",
"DreamShaper_8",
"AingDiffusion_v12",
"Flat2DAnimerge_v45Sharp",
"TextualInversion_VeryBadImageNegative_v1.3",
"StableDiffusionXL_v1",
"BluePencilXL_v200",
"StableDiffusionXL_Turbo",
"ControlNet_v11f1p_sd15_depth",
"ControlNet_v11p_sd15_softedge",
"ControlNet_v11f1e_sd15_tile",
"ControlNet_v11p_sd15_lineart",
"AnimateDiff_v2",
"AnimateDiff_xl_beta",
"RIFE",
"BeautifulPrompt",
"opus-mt-zh-en",
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5",
"Kolors",
"SDXL-vae-fp16-fix",
]
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = {}
self.model_path = {}
self.textual_inversion_dict = {}
downloaded_files = download_models(model_id_list, downloading_priority)
self.load_models(downloaded_files + file_path_list)
def load_model_from_origin(
self,
download_from: Preset_model_website = "ModelScope",
model_id = "",
origin_file_path = "",
local_dir = ""
):
website_to_download_fn[download_from](model_id, origin_file_path, local_dir)
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
self.load_model(file_to_download)
def is_stable_video_diffusion(self, state_dict):
param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
return param_name in state_dict
def is_RIFE(self, state_dict):
param_name = "block_tea.convblock3.0.1.weight"
return param_name in state_dict or ("module." + param_name) in state_dict
def is_beautiful_prompt(self, state_dict):
param_name = "transformer.h.9.self_attention.query_key_value.weight"
return param_name in state_dict
def is_stabe_diffusion_xl(self, state_dict):
param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
return param_name in state_dict
def is_stable_diffusion(self, state_dict):
if self.is_stabe_diffusion_xl(state_dict):
return False
param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
return param_name in state_dict
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
return param_name in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
return param_name in state_dict
def is_animatediff_xl(self, state_dict):
param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
return param_name in state_dict
def is_sd_lora(self, state_dict):
param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
return param_name in state_dict
def is_translator(self, state_dict):
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 258
def is_ipadapter(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
def is_ipadapter_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 521
def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict and len(state_dict) == 777
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict and param_name_ in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def is_ExVideo_StableVideoDiffusion(self, state_dict):
param_name = "blocks.185.positional_embedding.embeddings"
return param_name in state_dict
def is_stable_diffusion_3(self, state_dict):
param_names = [
"text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
"text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
"model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
"first_stage_model.encoder.mid.block_2.norm2.weight",
"first_stage_model.decoder.mid.block_2.norm2.weight",
]
for param_name in param_names:
if param_name not in state_dict:
return False
return True
def is_stable_diffusion_3_t5(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_kolors_text_encoder(self, file_path):
file_list = os.listdir(file_path)
if "config.json" in file_list:
try:
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if config.get("model_type") == "chatglm":
return True
except:
pass
return False
def is_kolors_unet(self, state_dict):
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
"unet": SVDUNet,
"vae_decoder": SVDVAEDecoder,
"vae_encoder": SVDVAEEncoder,
}
if components is None:
components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "unet":
self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDTextEncoder,
"unet": SDUNet,
"vae_decoder": SDVAEDecoder,
"vae_encoder": SDVAEEncoder,
"refiner": SDXLUNet,
}
if components is None:
components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
for component in components:
if component == "text_encoder":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
component_dict = {
"text_encoder": SDXLTextEncoder,
"text_encoder_2": SDXLTextEncoder2,
"unet": SDXLUNet,
"vae_decoder": SDXLVAEDecoder,
"vae_encoder": SDXLVAEEncoder,
}
if components is None:
components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
for component in components:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
if component in ["vae_decoder", "vae_encoder"]:
# These two model will output nan when float16 is enabled.
# The precision problem happens in the last three resnet blocks.
# I do not know how to solve this problem.
self.model[component].to(torch.float32).to(self.device)
else:
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_controlnet(self, state_dict, file_path=""):
component = "controlnet"
if component not in self.model:
self.model[component] = []
self.model_path[component] = []
model = SDControlNet()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component].append(model)
self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""):
component = "motion_modules"
model = SDMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_animatediff_xl(self, state_dict, file_path=""):
component = "motion_modules_xl"
model = SDXLMotionModel()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_beautiful_prompt(self, state_dict, file_path=""):
component = "beautiful_prompt"
from transformers import AutoModelForCausalLM
model_folder = os.path.dirname(file_path)
model = AutoModelForCausalLM.from_pretrained(
model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
).to(self.device).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_RIFE(self, state_dict, file_path=""):
component = "RIFE"
from ..extensions.RIFE import IFNet
model = IFNet().eval()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_sd_lora(self, state_dict, alpha):
SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
def load_translator(self, state_dict, file_path=""):
# This model is lightweight, we do not place it on GPU.
component = "translator"
from transformers import AutoModelForSeq2SeqLM
model_folder = os.path.dirname(file_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter(self, state_dict, file_path=""):
component = "ipadapter"
model = SDIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_image_encoder"
model = IpAdapterCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl"
model = SDXLIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder"
model = IpAdapterXLCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
unet_state_dict = self.model["unet"].state_dict()
self.model["unet"].to("cpu")
del self.model["unet"]
add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
self.model["unet"].load_state_dict(unet_state_dict, strict=False)
self.model["unet"].load_state_dict(state_dict, strict=False)
self.model["unet"].to(self.torch_dtype).to(self.device)
def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
component_dict = {
"sd3_text_encoder_1": SD3TextEncoder1,
"sd3_text_encoder_2": SD3TextEncoder2,
"sd3_text_encoder_3": SD3TextEncoder3,
"sd3_dit": SD3DiT,
"sd3_vae_decoder": SD3VAEDecoder,
"sd3_vae_encoder": SD3VAEEncoder,
}
if components is None:
components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
for component in components:
if component == "sd3_text_encoder_3":
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
continue
if component == "sd3_text_encoder_1":
# Add additional token embeddings to text encoder
token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
for keyword in self.textual_inversion_dict:
_, embeddings = self.textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
else:
self.model[component] = component_dict[component]()
self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
self.model[component].to(self.torch_dtype).to(self.device)
self.model_path[component] = file_path
def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
component = "sd3_text_encoder_3"
model = SD3TextEncoder3()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
component = "kolors_text_encoder"
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
model = model.to(dtype=self.torch_dtype, device=self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_unet(self, state_dict, file_path=""):
component = "kolors_unet"
model = SDXLUNet(is_kolors=True)
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += self.search_for_embeddings(state_dict[k])
return embeddings
def load_textual_inversions(self, folder):
# Store additional tokens here
self.textual_inversion_dict = {}
# Load every textual inversion file
for file_name in os.listdir(folder):
if os.path.isdir(os.path.join(folder, file_name)) or \
not (file_name.endswith(".bin") or \
file_name.endswith(".safetensors") or \
file_name.endswith(".pth") or \
file_name.endswith(".pt")):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
# Search for embeddings
for embeddings in self.search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
break
def load_model(self, file_path, components=None, lora_alphas=[]):
if os.path.isdir(file_path):
if self.is_kolors_text_encoder(file_path):
self.load_kolors_text_encoder(file_path=file_path)
return
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
elif self.is_animatediff(state_dict):
self.load_animatediff(state_dict, file_path=file_path)
elif self.is_animatediff_xl(state_dict):
self.load_animatediff_xl(state_dict, file_path=file_path)
elif self.is_controlnet(state_dict):
self.load_controlnet(state_dict, file_path=file_path)
elif self.is_stabe_diffusion_xl(state_dict):
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion(state_dict):
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
elif self.is_sd_lora(state_dict):
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
elif self.is_beautiful_prompt(state_dict):
self.load_beautiful_prompt(state_dict, file_path=file_path)
elif self.is_RIFE(state_dict):
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter(state_dict):
self.load_ipadapter(state_dict, file_path=file_path)
elif self.is_ipadapter_image_encoder(state_dict):
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
elif self.is_ExVideo_StableVideoDiffusion(state_dict):
self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
elif self.is_stable_diffusion_3(state_dict):
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion_3_t5(state_dict):
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
elif self.is_kolors_unet(state_dict):
self.load_kolors_unet(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
self.load_model(file_path, lora_alphas=lora_alphas)
def to(self, device):
for component in self.model:
if isinstance(self.model[component], list):
for model in self.model[component]:
model.to(device)
else:
self.model[component].to(device)
torch.cuda.empty_cache()
def get_model_with_model_path(self, model_path):
for component in self.model_path:
if isinstance(self.model_path[component], str):
if os.path.samefile(self.model_path[component], model_path):
return self.model[component]
elif isinstance(self.model_path[component], list):
for i, model_path_ in enumerate(self.model_path[component]):
if os.path.samefile(model_path_, model_path):
return self.model[component][i]
raise ValueError(f"Please load model {model_path} before you use it.")
def __getattr__(self, __name):
if __name in self.model:
return self.model[__name]
else:
return super.__getattribute__(__name)
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
from .model_manager import *

View File

@@ -1,15 +1,18 @@
from huggingface_hub import hf_hub_download
from modelscope import snapshot_download
import os, shutil
from typing_extensions import Literal, TypeAlias
from typing import List
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
def download_from_modelscope(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
@@ -21,8 +24,43 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
def download_from_huggingface(model_id, origin_file_path, local_dir):
os.makedirs(local_dir, exist_ok=True)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
print(f" {os.path.basename(origin_file_path)} has been already in {local_dir}.")
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
print(f" Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
"ModelScope",
]
website_to_preset_models = {
"HuggingFace": preset_models_on_huggingface,
"ModelScope": preset_models_on_modelscope,
}
website_to_download_fn = {
"HuggingFace": download_from_huggingface,
"ModelScope": download_from_modelscope,
}
def download_models(
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
):
print(f"Downloading models: {model_id_list}")
downloaded_files = []
for model_id in model_id_list:
for website in downloading_priority:
if model_id in website_to_preset_models[website]:
for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
# Check if the file is downloaded.
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
if file_to_download in downloaded_files:
continue
# Download
website_to_download_fn[website](model_id, origin_file_path, local_dir)
if os.path.basename(origin_file_path) in os.listdir(local_dir):
downloaded_files.append(file_to_download)
return downloaded_files

View File

@@ -1,5 +1,4 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch
@@ -399,7 +398,8 @@ class HunyuanDiT(torch.nn.Module):
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTStateDictConverter()

View File

@@ -79,7 +79,8 @@ class HunyuanDiTCLIPTextEncoder(BertModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTCLIPTextEncoderStateDictConverter()
@@ -131,7 +132,8 @@ class HunyuanDiTT5TextEncoder(T5EncoderModel):
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return HunyuanDiTT5TextEncoderStateDictConverter()

195
diffsynth/models/lora.py Normal file
View File

@@ -0,0 +1,195 @@
import torch
from .sd_unet import SDUNet
from .sdxl_unet import SDXLUNet
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sd3_dit import SD3DiT
from .hunyuan_dit import HunyuanDiT
class LoRAFromCivitai:
def __init__(self):
self.supported_model_classes = []
self.lora_prefix = []
self.renamed_lora_prefix = {}
self.special_keys = {}
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
for special_key in self.special_keys:
target_name = target_name.replace(special_key, self.special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
if model_resource == "diffusers":
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
elif model_resource == "civitai":
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
for model_resource in ["diffusers", "civitai"]:
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
else model.__class__.state_dict_converter().from_civitai
state_dict_lora_ = converter_fn(state_dict_lora_)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return lora_prefix, model_resource
except:
pass
return None
class SDLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDUNet, SDTextEncoder]
self.lora_prefix = ["lora_unet_", "lora_te_"]
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
}
class SDXLLoRAFromCivitai(LoRAFromCivitai):
def __init__(self):
super().__init__()
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
self.renamed_lora_prefix = {"lora_te2_": "2"}
self.special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
"text.model": "conditioner.embedders.0.transformer.text_model",
"self.attn.q.proj": "self_attn.q_proj",
"self.attn.k.proj": "self_attn.k_proj",
"self.attn.v.proj": "self_attn.v_proj",
"self.attn.out.proj": "self_attn.out_proj",
"input.blocks": "model.diffusion_model.input_blocks",
"middle.block": "model.diffusion_model.middle_block",
"output.blocks": "model.diffusion_model.output_blocks",
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
}
class GeneralLoRAFromPeft:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT]
def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16):
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
for name, param in state_dict_model.items():
torch_dtype = param.dtype
device = param.device
break
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0)
if len(state_dict_lora_) == 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return "", ""
except:
pass
return None

View File

@@ -0,0 +1,536 @@
import os, torch, hashlib, json, importlib
from safetensors import safe_open
from torch import Tensor
from typing_extensions import Literal, TypeAlias
from typing import List
from .downloader import download_models, Preset_model_id, Preset_model_website
from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from .sd3_dit import SD3DiT
from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel
from .svd_image_encoder import SVDImageEncoder
from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
if isinstance(key, str):
if isinstance(value, Tensor):
if with_shape:
shape = "_".join(map(str, list(value.shape)))
keys.append(key + ":" + shape)
keys.append(key)
elif isinstance(value, dict):
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
keys.sort()
keys_str = ",".join(keys)
return keys_str
def split_state_dict_with_prefix(state_dict):
keys = sorted([key for key in state_dict if isinstance(key, str)])
prefix_dict = {}
for key in keys:
prefix = key if "." not in key else key.split(".")[0]
if prefix not in prefix_dict:
prefix_dict[prefix] = []
prefix_dict[prefix].append(key)
state_dicts = []
for prefix, keys in prefix_dict.items():
sub_state_dict = {key: state_dict[key] for key in keys}
state_dicts.append(sub_state_dict)
return state_dicts
def hash_state_dict_keys(state_dict, with_shape=True):
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
keys_str = keys_str.encode(encoding="UTF-8")
return hashlib.md5(keys_str).hexdigest()
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
print(f" model_name: {model_name} model_class: {model_class.__name__}")
state_dict_converter = model_class.state_dict_converter()
if model_resource == "civitai":
state_dict_results = state_dict_converter.from_civitai(state_dict)
elif model_resource == "diffusers":
state_dict_results = state_dict_converter.from_diffusers(state_dict)
if isinstance(state_dict_results, tuple):
model_state_dict, extra_kwargs = state_dict_results
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
else:
model_state_dict, extra_kwargs = state_dict_results, {}
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
model = model_class(**extra_kwargs).to(dtype=torch_dtype, device=device)
model.load_state_dict(model_state_dict)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
if torch_dtype == torch.float16 and hasattr(model, "half"):
model = model.half()
model = model.to(device=device)
loaded_model_names.append(model_name)
loaded_models.append(model)
return loaded_model_names, loaded_models
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
base_state_dict = base_model.state_dict()
base_model.to("cpu")
del base_model
model = model_class(**extra_kwargs)
model.load_state_dict(base_state_dict, strict=False)
model.load_state_dict(state_dict, strict=False)
model.to(dtype=torch_dtype, device=device)
return model
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
loaded_model_names, loaded_models = [], []
for model_name, model_class in zip(model_names, model_classes):
while True:
for model_id in range(len(model_manager.model)):
base_model_name = model_manager.model_name[model_id]
if base_model_name == model_name:
base_model_path = model_manager.model_path[model_id]
base_model = model_manager.model[model_id]
print(f" Adding patch model to {base_model_name} ({base_model_path})")
patched_model = load_single_patch_model_from_single_file(
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
loaded_model_names.append(base_model_name)
loaded_models.append(patched_model)
model_manager.model.pop(model_id)
model_manager.model_path.pop(model_id)
model_manager.model_name.pop(model_id)
break
else:
break
return loaded_model_names, loaded_models
class ModelDetectorTemplate:
def __init__(self):
pass
def match(self, file_path="", state_dict={}):
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
return [], []
class ModelDetectorFromSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
self.keys_hash_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
if keys_hash is not None:
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
# Load models without strict matching
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
if keys_hash in self.keys_hash_dict:
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
return loaded_model_names, loaded_models
return loaded_model_names, loaded_models
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def __init__(self, model_loader_configs=[]):
super().__init__(model_loader_configs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
splited_state_dict = split_state_dict_with_prefix(state_dict)
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
# Split the state_dict and load from each component
splited_state_dict = split_state_dict_with_prefix(state_dict)
valid_state_dict = {}
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
valid_state_dict.update(sub_state_dict)
if super().match(file_path, valid_state_dict):
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
else:
loaded_model_names, loaded_models = [], []
for sub_state_dict in splited_state_dict:
if super().match(file_path, sub_state_dict):
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromHuggingfaceFolder:
def __init__(self, model_loader_configs=[]):
self.architecture_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name):
self.architecture_dict[architecture] = (huggingface_lib, model_name)
def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config:
return False
return True
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]:
huggingface_lib, model_name = self.architecture_dict[architecture]
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelDetectorFromPatchedSingleFile:
def __init__(self, model_loader_configs=[]):
self.keys_hash_with_shape_dict = {}
for metadata in model_loader_configs:
self.add_model_metadata(*metadata)
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
return True
return False
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
# Load models with strict matching
loaded_model_names, loaded_models = [], []
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
loaded_model_names += loaded_model_names_
loaded_models += loaded_models_
return loaded_model_names, loaded_models
class ModelManager:
def __init__(
self,
torch_dtype=torch.float16,
device="cuda",
model_id_list: List[Preset_model_id] = [],
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
file_path_list: List[str] = [],
):
self.torch_dtype = torch_dtype
self.device = device
self.model = []
self.model_path = []
self.model_name = []
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
self.model_detector = [
ModelDetectorFromSingleFile(model_loader_configs),
ModelDetectorFromSplitedSingleFile(model_loader_configs),
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
]
self.load_models(downloaded_files + file_path_list)
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
print(f"Loading models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
print(f"Loading models from folder: {file_path}")
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
print(f"Loading patch models from file: {file_path}")
model_names, models = load_patch_model_from_single_file(
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following patched models are loaded: {model_names}.")
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None):
print(f"Loading models from: {file_path}")
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None
for model_detector in self.model_detector:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
self.model.append(model)
self.model_path.append(file_path)
self.model_name.append(model_name)
print(f" The following models are loaded: {model_names}.")
break
else:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
if file_path is not None and file_path != model_path:
continue
if model_name == model_name_:
fetched_models.append(model)
fetched_model_paths.append(model_path)
if len(fetched_models) == 0:
print(f"No {model_name} models available.")
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
else:
return fetched_models[0]
def to(self, device):
for model in self.model:
model.to(device)

View File

@@ -228,7 +228,8 @@ class SD3DiT(torch.nn.Module):
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3DiTStateDictConverter()

View File

@@ -19,7 +19,8 @@ class SD3TextEncoder1(SDTextEncoder):
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
return pooled_embeds, hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder1StateDictConverter()
@@ -28,7 +29,8 @@ class SD3TextEncoder2(SDXLTextEncoder2):
def __init__(self):
super().__init__()
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder2StateDictConverter()
@@ -72,7 +74,8 @@ class SD3TextEncoder3(T5EncoderModel):
prompt_emb = outputs.last_hidden_state
return prompt_emb
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SD3TextEncoder3StateDictConverter()

View File

@@ -76,5 +76,6 @@ class SD3VAEDecoder(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDVAEDecoderStateDictConverter()

View File

@@ -90,5 +90,6 @@ class SD3VAEEncoder(torch.nn.Module):
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDVAEEncoderStateDictConverter()

View File

@@ -99,7 +99,7 @@ class SDControlNet(torch.nn.Module):
tiled=False, tile_size=64, tile_stride=32,
):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
time_emb = time_emb.repeat(sample.shape[0], 1)
@@ -134,7 +134,8 @@ class SDControlNet(torch.nn.Module):
return controlnet_res_stack
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDControlNetStateDictConverter()

View File

@@ -47,7 +47,8 @@ class SDIpAdapter(torch.nn.Module):
}
return ip_kv_dict
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDIpAdapterStateDictConverter()

View File

@@ -1,60 +0,0 @@
import torch
from .sd_unet import SDUNetStateDictConverter, SDUNet
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
class SDLoRA:
def __init__(self):
pass
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
special_keys = {
"down.blocks": "down_blocks",
"up.blocks": "up_blocks",
"mid.block": "mid_block",
"proj.in": "proj_in",
"proj.out": "proj_out",
"transformer.blocks": "transformer_blocks",
"to.q": "to_q",
"to.k": "to_k",
"to.v": "to_v",
"to.out": "to_out",
}
state_dict_ = {}
for key in state_dict:
if ".lora_up" not in key:
continue
if not key.startswith(lora_prefix):
continue
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
for special_key in special_keys:
target_name = target_name.replace(special_key, special_keys[special_key])
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_unet = unet.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_unet[name] += state_dict_lora[name].to(device=device)
unet.load_state_dict(state_dict_unet)
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
state_dict_text_encoder = text_encoder.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
if len(state_dict_lora) > 0:
for name in state_dict_lora:
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
text_encoder.load_state_dict(state_dict_text_encoder)

View File

@@ -144,7 +144,8 @@ class SDMotionModel(torch.nn.Module):
def forward(self):
pass
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDMotionModelStateDictConverter()

View File

@@ -71,7 +71,8 @@ class SDTextEncoder(torch.nn.Module):
embeds = self.final_layer_norm(embeds)
return embeds
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDTextEncoderStateDictConverter()

View File

@@ -323,7 +323,7 @@ class SDUNet(torch.nn.Module):
def forward(self, sample, timestep, encoder_hidden_states, **kwargs):
# 1. time
time_emb = self.time_proj(timestep[None]).to(sample.dtype)
time_emb = self.time_proj(timestep).to(sample.dtype)
time_emb = self.time_embedding(time_emb)
# 2. pre-process
@@ -342,7 +342,8 @@ class SDUNet(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDUNetStateDictConverter()

View File

@@ -90,6 +90,8 @@ class SDVAEDecoder(torch.nn.Module):
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
original_dtype = sample.dtype
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -110,10 +112,12 @@ class SDVAEDecoder(torch.nn.Module):
hidden_states = self.conv_norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states.to(original_dtype)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDVAEDecoderStateDictConverter()

View File

@@ -50,6 +50,8 @@ class SDVAEEncoder(torch.nn.Module):
return hidden_states
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
original_dtype = sample.dtype
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
# For VAE Decoder, we do not need to apply the tiler on each layer.
if tiled:
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
@@ -71,6 +73,7 @@ class SDVAEEncoder(torch.nn.Module):
hidden_states = self.quant_conv(hidden_states)
hidden_states = hidden_states[:, :4]
hidden_states *= self.scaling_factor
hidden_states = hidden_states.to(original_dtype)
return hidden_states
@@ -91,7 +94,8 @@ class SDVAEEncoder(torch.nn.Module):
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDVAEEncoderStateDictConverter()

View File

@@ -96,7 +96,8 @@ class SDXLIpAdapter(torch.nn.Module):
}
return ip_kv_dict
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLIpAdapterStateDictConverter()

View File

@@ -49,7 +49,8 @@ class SDXLMotionModel(torch.nn.Module):
def forward(self):
pass
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDMotionModelStateDictConverter()

View File

@@ -36,7 +36,8 @@ class SDXLTextEncoder(torch.nn.Module):
break
return embeds
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLTextEncoderStateDictConverter()
@@ -80,7 +81,8 @@ class SDXLTextEncoder2(torch.nn.Module):
pooled_embeds = self.text_projection(pooled_embeds)
return pooled_embeds, hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLTextEncoder2StateDictConverter()

View File

@@ -91,7 +91,7 @@ class SDXLUNet(torch.nn.Module):
**kwargs
):
# 1. time
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
@@ -133,7 +133,8 @@ class SDXLUNet(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLUNetStateDictConverter()
@@ -197,7 +198,10 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
@@ -1889,4 +1893,7 @@ class SDXLUNetStateDictConverter:
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
if "text_intermediate_proj.weight" in state_dict_:
return state_dict_, {"is_kolors": True}
else:
return state_dict_

View File

@@ -2,14 +2,23 @@ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
class SDXLVAEDecoder(SDVAEDecoder):
def __init__(self):
def __init__(self, upcast_to_float32=True):
super().__init__()
self.scaling_factor = 0.13025
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLVAEDecoderStateDictConverter()
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
def __init__(self):
super().__init__()
def from_diffusers(self, state_dict):
state_dict = super().from_diffusers(state_dict)
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
state_dict = super().from_civitai(state_dict)
return state_dict, {"upcast_to_float32": True}

View File

@@ -2,14 +2,23 @@ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
class SDXLVAEEncoder(SDVAEEncoder):
def __init__(self):
def __init__(self, upcast_to_float32=True):
super().__init__()
self.scaling_factor = 0.13025
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SDXLVAEEncoderStateDictConverter()
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
def __init__(self):
super().__init__()
def from_diffusers(self, state_dict):
state_dict = super().from_diffusers(state_dict)
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
state_dict = super().from_civitai(state_dict)
return state_dict, {"upcast_to_float32": True}

View File

@@ -44,7 +44,8 @@ class SVDImageEncoder(torch.nn.Module):
embeds = self.visual_projection(embeds)
return embeds
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SVDImageEncoderStateDictConverter()

View File

@@ -407,7 +407,8 @@ class SVDUNet(torch.nn.Module):
return hidden_states
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SVDUNetStateDictConverter()

View File

@@ -199,7 +199,8 @@ class SVDVAEDecoder(torch.nn.Module):
return values
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SVDVAEDecoderStateDictConverter()

View File

@@ -6,7 +6,8 @@ class SVDVAEEncoder(SDVAEEncoder):
super().__init__()
self.scaling_factor = 0.13025
def state_dict_converter(self):
@staticmethod
def state_dict_converter():
return SVDVAEEncoderStateDictConverter()

View File

@@ -1,8 +1,8 @@
from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
from .stable_diffusion_xl_video import SDXLVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .hunyuan_dit import HunyuanDiTImagePipeline
from .stable_diffusion_3 import SD3ImagePipeline
from .kwai_kolors import KolorsImagePipeline
from .sd_image import SDImagePipeline
from .sd_video import SDVideoPipeline
from .sdxl_image import SDXLImagePipeline
from .sdxl_video import SDXLVideoPipeline
from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .pipeline_runner import SDVideoPipelineRunner

View File

@@ -0,0 +1,34 @@
import torch
import numpy as np
from PIL import Image
class BasePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.device = device
self.torch_dtype = torch_dtype
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def preprocess_images(self, images):
return [self.preprocess_image(image) for image in images]
def vae_output_to_image(self, vae_output):
image = vae_output[0].cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def vae_output_to_video(self, vae_output):
video = vae_output.cpu().permute(1, 2, 0).numpy()
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
return video

View File

@@ -22,6 +22,10 @@ def lets_dance(
device = "cuda",
vram_limit_level = 0,
):
# 0. Text embedding alignment (only for video processing)
if encoder_hidden_states.shape[0] != sample.shape[0]:
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0], 1, 1, 1)
# 1. ControlNet
# This part will be repeated on overlapping frames if animatediff_batch_size > animatediff_stride.
# I leave it here because I intend to do something interesting on the ControlNets.
@@ -50,7 +54,7 @@ def lets_dance(
additional_res_stack = None
# 2. time
time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
time_emb = unet.time_proj(timestep).to(sample.dtype)
time_emb = unet.time_embedding(time_emb)
# 3. pre-process
@@ -133,7 +137,7 @@ def lets_dance_xl(
vram_limit_level = 0,
):
# 2. time
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
time_embeds = unet.add_time_proj(add_time_id)

View File

@@ -3,11 +3,11 @@ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, Hunyuan
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompts import HunyuanDiTPrompter
from ..prompters import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
@@ -122,14 +122,12 @@ class ImageSizeManager:
class HunyuanDiTImagePipeline(torch.nn.Module):
class HunyuanDiTImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter()
self.device = device
self.torch_dtype = torch_dtype
self.image_size_manager = ImageSizeManager()
# models
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
@@ -139,42 +137,60 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
self.vae_encoder: SDXLVAEEncoder = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
self.dit = model_manager.hunyuan_dit
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def denoising_model(self):
return self.dit
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
self.text_encoder_t5 = model_manager.fetch_model("hunyuan_dit_t5_text_encoder")
self.dit = model_manager.fetch_model("hunyuan_dit")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder, self.text_encoder_t5)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager):
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = HunyuanDiTImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=1, positive=True):
text_emb, text_emb_mask, text_emb_t5, text_emb_mask_t5 = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=positive,
device=self.device
)
return {
"text_emb": text_emb,
"text_emb_mask": text_emb_mask,
"text_emb_t5": text_emb_t5,
"text_emb_mask_t5": text_emb_mask_t5
}
def prepare_extra_input(self, latents=None, tiled=False, tile_size=64, tile_stride=32):
batch_size, height, width = latents.shape[0], latents.shape[2] * 8, latents.shape[3] * 8
if tiled:
height, width = tile_size * 16, tile_size * 16
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
@@ -198,7 +214,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
clip_skip=1,
clip_skip_2=1,
input_image=None,
reference_images=[],
reference_strengths=[0.4],
denoising_strength=1.0,
height=1024,
@@ -222,65 +237,26 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
else:
latents = noise.clone()
# Prepare reference latents
reference_latents = []
for reference_image in reference_images:
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
# Encode prompts
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=True,
device=self.device
)
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
negative_prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=False,
device=self.device
)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
# Prepare positional id
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
# In-context reference
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
if progress_id < num_inference_steps * reference_strength:
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
self.dit(
noisy_reference_latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
to_cache=True
)
# Positive side
noise_pred_posi = self.dit(
latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
latents, timestep=timestep, **prompt_emb_posi, **extra_input,
)
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = self.dit(
latents,
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
timestep,
**extra_input
latents, timestep=timestep, **prompt_emb_nega, **extra_input,
)
# Classifier-free guidance
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

View File

@@ -1,168 +0,0 @@
from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..prompts import KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class KolorsImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
self.prompter = KolorsPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.kolors_text_encoder
self.unet = model_manager.kolors_unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter_xl" in model_manager.model:
self.ipadapter = model_manager.ipadapter_xl
if "ipadapter_xl_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager):
pipe = KolorsImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
prompt,
clip_skip=clip_skip,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
negative_prompt,
clip_skip=clip_skip,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,105 @@
import os, torch, json
from .sd_video import ModelManager, SDVideoPipeline, ControlNetConfigUnit
from ..processors.sequencial_processor import SequencialProcessor
from ..data import VideoData, save_frames, save_video
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_models(model_list)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
textual_inversion_paths = []
for file_name in os.listdir(textual_inversion_folder):
if file_name.endswith(".pt") or file_name.endswith(".bin") or file_name.endswith(".pth") or file_name.endswith(".safetensors"):
textual_inversion_paths.append(os.path.join(textual_inversion_folder, file_name))
pipe.prompter.load_textual_inversions(textual_inversion_paths)
return model_manager, pipe
def load_smoother(self, model_manager, smoother_configs):
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
return smoother
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs, smoother=smoother)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
if start_frame_id is None:
start_frame_id = 0
if end_frame_id is None:
end_frame_id = len(video)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if "smoother_configs" in config:
if self.in_streamlit: st.markdown("Loading smoother ...")
smoother = self.load_smoother(model_manager, config["smoother_configs"])
if self.in_streamlit: st.markdown("Loading smoother ... done!")
else:
smoother = None
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -1,20 +1,18 @@
from ..models import ModelManager, SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEDecoder, SD3VAEEncoder
from ..prompts import SD3Prompter
from ..prompters import SD3Prompter
from ..schedulers import FlowMatchScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SD3ImagePipeline(torch.nn.Module):
class SD3ImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = FlowMatchScheduler()
self.prompter = SD3Prompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: SD3TextEncoder2 = None
@@ -24,43 +22,54 @@ class SD3ImagePipeline(torch.nn.Module):
self.vae_encoder: SD3VAEEncoder = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder_1 = model_manager.sd3_text_encoder_1
self.text_encoder_2 = model_manager.sd3_text_encoder_2
def denoising_model(self):
return self.dit
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
if "sd3_text_encoder_3" in model_manager.model:
self.text_encoder_3 = model_manager.sd3_text_encoder_3
self.dit = model_manager.sd3_dit
self.vae_decoder = model_manager.sd3_vae_decoder
self.vae_encoder = model_manager.sd3_vae_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.dit = model_manager.fetch_model("sd3_dit")
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2, self.text_encoder_3)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager):
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = SD3ImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, positive=True):
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
prompt, device=self.device, positive=positive
)
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
@@ -78,42 +87,35 @@ class SD3ImagePipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi, pooled_prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder_1, self.text_encoder_2, self.text_encoder_3,
prompt,
device=self.device, positive=True
)
prompt_emb_nega, pooled_prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder_1, self.text_encoder_2, self.text_encoder_3,
negative_prompt,
device=self.device, positive=False
)
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.Tensor((timestep,)).to(self.device)
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = self.dit(
latents, timestep, prompt_emb_posi, pooled_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
)
noise_pred_nega = self.dit(
latents, timestep, prompt_emb_nega, pooled_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

View File

@@ -1,23 +1,22 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDImagePipeline(torch.nn.Module):
class SDImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
@@ -28,61 +27,65 @@ class SDImagePipeline(torch.nn.Module):
self.ipadapter: SDIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def denoising_model(self):
return self.unet
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.get_model_with_model_path(config.model_path),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter" in model_manager.model:
self.ipadapter = model_manager.ipadapter
if "ipadapter_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
pipe.fetch_ipadapter(model_manager)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes=[])
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, clip_skip=clip_skip, device=self.device, positive=positive)
return {"encoder_hidden_states": prompt_emb}
def prepare_extra_input(self, latents=None):
return {}
@torch.no_grad()
def __call__(
self,
@@ -104,53 +107,56 @@ class SDImagePipeline(torch.nn.Module):
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
device=self.device, vram_limit_level=0
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
device=self.device, vram_limit_level=0
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

View File

@@ -0,0 +1,266 @@
from ..models import SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder, SDMotionModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sd_image import SDImagePipeline
from .dancer import lets_dance
from typing import List
import torch
from tqdm import tqdm
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
tiled=False,
tile_size=64,
tile_stride=32,
device="cuda",
animatediff_batch_size=16,
animatediff_stride=8,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states,
ipadapter_kwargs_list=ipadapter_kwargs_list,
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, device=device
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + bias)
if batch_id_ == num_frames:
break
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(SDImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDPrompter()
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
self.motion_modules: SDMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sd_text_encoder")
self.unet = model_manager.fetch_model("sd_unet")
self.vae_decoder = model_manager.fetch_model("sd_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sd_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sd_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sd_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sd_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
other_kwargs = {
"animatediff_batch_size": animatediff_batch_size, "animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size, "controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **other_kwargs, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return image

View File

@@ -0,0 +1,191 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
class SDXLImagePipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO)
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
def denoising_model(self):
return self.unet
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
# ControlNets (TODO)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter and scheduler will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDXLImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def encode_image(self, image, tiled=False, tile_size=64, tile_stride=32):
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return latents
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.vae_output_to_image(image)
return image
def encode_prompt(self, prompt, clip_skip=1, clip_skip_2=2, positive=True):
add_prompt_emb, prompt_emb = self.prompter.encode_prompt(
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=positive,
)
return {"encoder_hidden_states": prompt_emb, "add_text_embeds": add_prompt_emb}
def prepare_extra_input(self, latents=None):
height, width = latents.shape[2] * 8, latents.shape[3] * 8
return {"add_time_id": torch.tensor([height, width, 0, 0, height, width], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets (TODO)
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=None, controlnet=None,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=None, controlnet=None,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,223 @@
from ..models import SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder, SDXLMotionModel
from ..models.kolors_text_encoder import ChatGLMModel
from ..models.model_manager import ModelManager
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import SDXLPrompter, KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .sdxl_image import SDXLImagePipeline
from .dancer import lets_dance_xl
from typing import List
import torch
from tqdm import tqdm
class SDXLVideoPipeline(SDXLImagePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_original_animatediff=True):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_original_animatediff else "scaled_linear")
self.prompter = SDXLPrompter()
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.text_encoder_kolors: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO)
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
self.motion_modules: SDXLMotionModel = None
def fetch_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
# Main models
self.text_encoder = model_manager.fetch_model("sdxl_text_encoder")
self.text_encoder_2 = model_manager.fetch_model("sdxl_text_encoder_2")
self.text_encoder_kolors = model_manager.fetch_model("kolors_text_encoder")
self.unet = model_manager.fetch_model("sdxl_unet")
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
# ControlNets (TODO)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("sdxl_ipadapter_clip_image_encoder")
# Motion Modules
self.motion_modules = model_manager.fetch_model("sdxl_motion_modules")
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
# Kolors
if self.text_encoder_kolors is not None:
print("Switch to Kolors. The prompter will be replaced.")
self.prompter = KolorsPrompter()
self.prompter.fetch_models(self.text_encoder_kolors)
# The schedulers of AniamteDiff and Kolors are incompatible. We align it with AniamteDiff.
if self.motion_modules is None:
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
else:
self.prompter.fetch_models(self.text_encoder, self.text_encoder_2)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[]):
pipe = SDXLVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes)
return pipe
def decode_video(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_video(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents.append(latent.cpu())
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters, batch size ...
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_video(input_frames, **tiler_kwargs)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
latents = latents.to(self.device) # will be deleted for supporting long videos
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
controlnet_kwargs = {"controlnet_frames": controlnet_frames}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_posi, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, timestep=timestep,
**prompt_emb_nega, **controlnet_kwargs, **ipadapter_kwargs_list_nega, **extra_input, **tiler_kwargs,
device=self.device,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_video(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_video(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_video(latents, **tiler_kwargs)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return image

View File

@@ -1,356 +0,0 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDMotionModel
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler
from ..data import VideoData, save_frames, save_video
from .dancer import lets_dance
from ..processors.sequencial_processor import SequencialProcessor
from typing import List
import torch, os, json
from tqdm import tqdm
from PIL import Image
import numpy as np
def lets_dance_with_long_video(
unet: SDUNet,
motion_modules: SDMotionModel = None,
controlnet: MultiControlNetManager = None,
sample = None,
timestep = None,
encoder_hidden_states = None,
controlnet_frames = None,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
device = "cuda",
vram_limit_level = 0,
):
num_frames = sample.shape[0]
hidden_states_output = [(torch.zeros(sample[0].shape, dtype=sample[0].dtype), 0) for i in range(num_frames)]
for batch_id in range(0, num_frames, animatediff_stride):
batch_id_ = min(batch_id + animatediff_batch_size, num_frames)
# process this batch
hidden_states_batch = lets_dance(
unet, motion_modules, controlnet,
sample[batch_id: batch_id_].to(device),
timestep,
encoder_hidden_states[batch_id: batch_id_].to(device),
controlnet_frames=controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=device, vram_limit_level=vram_limit_level
).cpu()
# update hidden_states
for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch):
bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2)
hidden_states, num = hidden_states_output[i]
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
hidden_states_output[i] = (hidden_states, num + bias)
if batch_id_ == num_frames:
break
# output
hidden_states = torch.stack([h for h, _ in hidden_states_output])
return hidden_states
class SDVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
self.prompter = SDPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDTextEncoder = None
self.unet: SDUNet = None
self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None
self.motion_modules: SDMotionModel = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.get_model_with_model_path(config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_motion_modules(self, model_manager: ModelManager):
if "motion_modules" in model_manager.model:
self.motion_modules = model_manager.motion_modules
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[]):
pipe = SDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
use_animatediff="motion_modules" in model_manager.model
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
latents.append(latent)
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
num_frames=None,
input_frames=None,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True).cpu()
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False).cpu()
prompt_emb_posi = prompt_emb_posi.repeat(num_frames, 1, 1)
prompt_emb_nega = prompt_emb_nega.repeat(num_frames, 1, 1)
# Prepare ControlNets
if controlnet_frames is not None:
if isinstance(controlnet_frames[0], list):
controlnet_frames_ = []
for processor_id in range(len(controlnet_frames)):
controlnet_frames_.append(
torch.stack([
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
], dim=1)
)
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
else:
controlnet_frames = torch.stack([
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
for controlnet_frame in progress_bar_cmd(controlnet_frames)
], dim=1)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred_nega = lets_dance_with_long_video(
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
# DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids:
rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True)
rendered_frames = self.decode_images(rendered_frames)
rendered_frames = smoother(rendered_frames, original_frames=input_frames)
target_latents = self.encode_images(rendered_frames)
noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents)
latents = self.scheduler.step(noise_pred, timestep, latents)
# UI
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
output_frames = self.decode_images(latents)
# Post-process
if smoother is not None and (num_inference_steps in smoother_progress_ids or -1 in smoother_progress_ids):
output_frames = smoother(output_frames, original_frames=input_frames)
return output_frames
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_textual_inversions(textual_inversion_folder)
model_manager.load_models(model_list, lora_alphas=lora_alphas)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
return model_manager, pipe
def load_smoother(self, model_manager, smoother_configs):
smoother = SequencialProcessor.from_model_manager(model_manager, smoother_configs)
return smoother
def synthesize_video(self, model_manager, pipe, seed, smoother, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, smoother=smoother, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs, smoother=smoother)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
if start_frame_id is None:
start_frame_id = 0
if end_frame_id is None:
end_frame_id = len(video)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if "smoother_configs" in config:
if self.in_streamlit: st.markdown("Loading smoother ...")
smoother = self.load_smoother(model_manager, config["smoother_configs"])
if self.in_streamlit: st.markdown("Loading smoother ... done!")
else:
smoother = None
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], smoother, **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -1,180 +0,0 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDXLImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler()
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
# TODO: SDXL ControlNet
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter_xl" in model_manager.model:
self.ipadapter = model_manager.ipadapter_xl
if "ipadapter_xl_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
controlnet_image=None,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -1,190 +0,0 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
from .dancer import lets_dance_xl
# TODO: SDXL ControlNet
from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class SDXLVideoPipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
self.prompter = SDXLPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# TODO: SDXL ControlNet
self.motion_modules: SDXLMotionModel = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.text_encoder
self.text_encoder_2 = model_manager.text_encoder_2
self.unet = model_manager.unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
# TODO: SDXL ControlNet
pass
def fetch_motion_modules(self, model_manager: ModelManager):
if "motion_modules_xl" in model_manager.model:
self.motion_modules = model_manager.motion_modules_xl
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
pipe = SDXLVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
use_animatediff="motion_modules_xl" in model_manager.model
)
pipe.fetch_main_models(model_manager)
pipe.fetch_motion_modules(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
images = [
self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
for frame_id in range(latents.shape[0])
]
return images
def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
latents = []
for image in processed_images:
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
latents.append(latent)
latents = torch.concat(latents, dim=0)
return latents
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=2,
num_frames=None,
input_frames=None,
controlnet_frames=None,
denoising_strength=1.0,
height=512,
width=512,
num_inference_steps=20,
animatediff_batch_size = 16,
animatediff_stride = 8,
unet_batch_size = 1,
controlnet_batch_size = 1,
cross_frame_attention = False,
smoother=None,
smoother_progress_ids=[],
vram_limit_level=0,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if self.motion_modules is None:
noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
else:
noise = torch.randn((num_frames, 4, height//8, width//8), device="cuda", dtype=self.torch_dtype)
if input_frames is None or denoising_strength == 1.0:
latents = noise
else:
latents = self.encode_images(input_frames)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_2,
negative_prompt,
clip_skip=clip_skip, clip_skip_2=clip_skip_2,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=self.motion_modules, controlnet=None,
sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_images(latents.to(torch.float32))
return image

View File

@@ -1,5 +1,6 @@
from ..models import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, SVDVAEDecoder
from ..schedulers import ContinuousODEScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
@@ -8,13 +9,11 @@ from einops import rearrange, repeat
class SVDVideoPipeline(torch.nn.Module):
class SVDVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = ContinuousODEScheduler()
self.device = device
self.torch_dtype = torch_dtype
# models
self.image_encoder: SVDImageEncoder = None
self.unet: SVDUNet = None
@@ -22,32 +21,23 @@ class SVDVideoPipeline(torch.nn.Module):
self.vae_decoder: SVDVAEDecoder = None
def fetch_main_models(self, model_manager: ModelManager):
self.image_encoder = model_manager.image_encoder
self.unet = model_manager.unet
self.vae_encoder = model_manager.vae_encoder
self.vae_decoder = model_manager.vae_decoder
def fetch_models(self, model_manager: ModelManager):
self.image_encoder = model_manager.fetch_model("svd_image_encoder")
self.unet = model_manager.fetch_model("svd_unet")
self.vae_encoder = model_manager.fetch_model("svd_vae_encoder")
self.vae_decoder = model_manager.fetch_model("svd_vae_decoder")
@staticmethod
def from_model_manager(model_manager: ModelManager, **kwargs):
pipe = SVDVideoPipeline(device=model_manager.device, torch_dtype=model_manager.torch_dtype)
pipe.fetch_main_models(model_manager)
pipe = SVDVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def encode_image_with_clip(self, image):
image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
image = SVDCLIPImageProcessor().resize_with_antialiasing(image, (224, 224))

View File

@@ -1,3 +1,4 @@
from .prompt_refiners import Translator, BeautifulPrompt
from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter

View File

@@ -0,0 +1,57 @@
from ..models.model_manager import ModelManager
import torch
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length if max_length is None else max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BasePrompter:
def __init__(self, refiners=[]):
self.refiners = refiners
def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]):
for refiner_class in refiner_classes:
refiner = refiner_class.from_model_manager(model_nameger)
self.refiners.append(refiner)
@torch.no_grad()
def process_prompt(self, prompt, positive=True):
if isinstance(prompt, list):
prompt = [self.process_prompt(prompt_, positive=positive) for prompt_ in prompt]
else:
for refiner in self.refiners:
prompt = refiner(prompt, positive=positive)
return prompt

View File

@@ -1,9 +1,11 @@
from .utils import Prompter
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from transformers import BertTokenizer, AutoTokenizer
import warnings, os
class HunyuanDiTPrompter(Prompter):
class HunyuanDiTPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None,
@@ -20,6 +22,13 @@ class HunyuanDiTPrompter(Prompter):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
self.text_encoder = text_encoder
self.text_encoder_t5 = text_encoder_t5
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
@@ -43,8 +52,6 @@ class HunyuanDiTPrompter(Prompter):
def encode_prompt(
self,
text_encoder: BertModel,
text_encoder_t5: T5EncoderModel,
prompt,
clip_skip=1,
clip_skip_2=1,
@@ -54,9 +61,9 @@ class HunyuanDiTPrompter(Prompter):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
# T5
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5

View File

@@ -1,4 +1,5 @@
from .utils import Prompter
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
import json, os, re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
@@ -302,7 +303,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
class KolorsPrompter(Prompter):
class KolorsPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None
@@ -312,6 +313,11 @@ class KolorsPrompter(Prompter):
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
super().__init__()
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
self.text_encoder: ChatGLMModel = None
def fetch_models(self, text_encoder: ChatGLMModel = None):
self.text_encoder = text_encoder
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
@@ -335,13 +341,13 @@ class KolorsPrompter(Prompter):
def encode_prompt(
self,
text_encoder: ChatGLMModel,
prompt,
clip_skip=2,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, text_encoder, self.tokenizer, 256, clip_skip, device)
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device)
return pooled_prompt_emb, prompt_emb

View File

@@ -0,0 +1,77 @@
from transformers import AutoTokenizer
from ..models.model_manager import ModelManager
import torch
class BeautifulPrompt(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None, template=""):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = template
@staticmethod
def from_model_manager(model_nameger: ModelManager):
model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
if model_path.endswith("v2"):
template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
beautiful_prompt = BeautifulPrompt(
tokenizer_path=model_path,
model=model,
template=template
)
return beautiful_prompt
def __call__(self, raw_prompt, positive=True, **kwargs):
if positive:
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
return prompt
else:
return raw_prompt
class Translator(torch.nn.Module):
def __init__(self, tokenizer_path=None, model=None):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
@staticmethod
def from_model_manager(model_nameger: ModelManager):
model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
translator = Translator(tokenizer_path=model_path, model=model)
return translator
def __call__(self, prompt, **kwargs):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
print(f"Your prompt is translated: {prompt}")
return prompt

View File

@@ -1,9 +1,11 @@
from .utils import Prompter
from .base_prompter import BasePrompter
from ..models.model_manager import ModelManager
from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
class SD3Prompter(Prompter):
class SD3Prompter(BasePrompter):
def __init__(
self,
tokenizer_1_path=None,
@@ -20,9 +22,18 @@ class SD3Prompter(Prompter):
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: SD3TextEncoder2 = None
self.text_encoder_3: SD3TextEncoder3 = None
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
self.text_encoder_3 = text_encoder_3
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
@@ -54,24 +65,21 @@ class SD3Prompter(Prompter):
def encode_prompt(
self,
text_encoder_1,
text_encoder_2,
text_encoder_3,
prompt,
positive=True,
device="cuda"
):
prompt, pure_prompt = self.process_prompt(prompt, positive=positive, require_pure_prompt=True)
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, text_encoder_1, self.tokenizer, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(pure_prompt, text_encoder_2, self.tokenizer_2, 77, device)
pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, 77, device)
pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, 77, device)
# T5
if text_encoder_3 is None:
if self.text_encoder_3 is None:
prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], 256, 4096), dtype=prompt_emb_1.dtype, device=device)
else:
prompt_emb_3 = self.encode_prompt_using_t5(pure_prompt, text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, 256, device)
prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype) # float32 -> float16
# Merge

View File

@@ -0,0 +1,73 @@
from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
from ..models import SDTextEncoder
from transformers import CLIPTokenizer
import torch, os
class SDPrompter(BasePrompter):
def __init__(self, tokenizer_path=None):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.text_encoder: SDTextEncoder = None
self.textual_inversion_dict = {}
self.keyword_dict = {}
def fetch_models(self, text_encoder: SDTextEncoder = None):
self.text_encoder = text_encoder
def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
dtype = next(iter(text_encoder.parameters())).dtype
state_dict = text_encoder.token_embedding.state_dict()
token_embeddings = [state_dict["weight"]]
for keyword in textual_inversion_dict:
_, embeddings = textual_inversion_dict[keyword]
token_embeddings.append(embeddings.to(dtype=dtype, device=token_embeddings[0].device))
token_embeddings = torch.concat(token_embeddings, dim=0)
state_dict["weight"] = token_embeddings
text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=token_embeddings[0].device)
text_encoder.token_embedding.load_state_dict(state_dict)
def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
tokenizer.add_tokens(additional_tokens)
def load_textual_inversions(self, model_paths):
for model_path in model_paths:
keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
state_dict = load_state_dict(model_path)
# Search for embeddings
for embeddings in search_for_embeddings(state_dict):
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
self.textual_inversion_dict[keyword] = (tokens, embeddings)
self.add_textual_inversions_to_model(self.textual_inversion_dict, self.text_encoder)
self.add_textual_inversions_to_tokenizer(self.textual_inversion_dict, self.tokenizer)
def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
prompt = self.process_prompt(prompt, positive=positive)
for keyword in self.keyword_dict:
if keyword in prompt:
print(f"Textual inversion {keyword} is enabled.")
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb

View File

@@ -1,10 +1,12 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.model_manager import ModelManager
from ..models import SDXLTextEncoder, SDXLTextEncoder2
from transformers import CLIPTokenizer
import torch, os
class SDXLPrompter(Prompter):
class SDXLPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None,
@@ -19,11 +21,17 @@ class SDXLPrompter(Prompter):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
self.text_encoder: SDXLTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
def fetch_models(self, text_encoder: SDXLTextEncoder = None, text_encoder_2: SDXLTextEncoder2 = None):
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
def encode_prompt(
self,
text_encoder: SDXLTextEncoder,
text_encoder_2: SDXLTextEncoder2,
prompt,
clip_skip=1,
clip_skip_2=2,
@@ -34,11 +42,11 @@ class SDXLPrompter(Prompter):
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb_1 = self.text_encoder(input_ids, clip_skip=clip_skip)
# 2
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
add_text_embeds, prompt_emb_2 = self.text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
# Merge
if prompt_emb_1.shape[0] != prompt_emb_2.shape[0]:

View File

@@ -1,21 +0,0 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDTextEncoder
import os
class SDPrompter(Prompter):
def __init__(self, tokenizer_path=None):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb

View File

@@ -1,144 +0,0 @@
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import ModelManager
import os
def tokenize_long_prompt(tokenizer, prompt):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BeautifulPrompt:
def __init__(self, tokenizer_path=None, model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class Translator:
def __init__(self, tokenizer_path=None, model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
def __call__(self, prompt):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
return prompt
class Prompter:
def __init__(self):
self.tokenizer: CLIPTokenizer = None
self.keyword_dict = {}
self.translator: Translator = None
self.beautiful_prompt: BeautifulPrompt = None
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
if self.tokenizer is not None:
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
def load_translator(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.translator = Translator(tokenizer_path=model_folder, model=model)
def load_from_model_manager(self, model_manager: ModelManager):
self.load_textual_inversion(model_manager.textual_inversion_dict)
if "translator" in model_manager.model:
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
if "beautiful_prompt" in model_manager.model:
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def add_textual_inversion_tokens(self, prompt):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
return prompt
def del_textual_inversion_tokens(self, prompt):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, "")
return prompt
def process_prompt(self, prompt, positive=True, require_pure_prompt=False):
if isinstance(prompt, list):
prompt = [self.process_prompt(prompt_, positive=positive, require_pure_prompt=require_pure_prompt) for prompt_ in prompt]
if require_pure_prompt:
prompt, pure_prompt = [i[0] for i in prompt], [i[1] for i in prompt]
return prompt, pure_prompt
else:
return prompt
prompt, pure_prompt = self.add_textual_inversion_tokens(prompt), self.del_textual_inversion_tokens(prompt)
if positive and self.translator is not None:
prompt = self.translator(prompt)
print(f"Your prompt is translated: \"{prompt}\"")
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
if require_pure_prompt:
return prompt, pure_prompt
else:
return prompt

View File

@@ -22,10 +22,10 @@ class EnhancedDDIMScheduler():
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = [max_timestep]
self.timesteps = torch.Tensor([max_timestep])
else:
step_length = max_timestep / (num_inference_steps - 1)
self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
@@ -43,31 +43,37 @@ class EnhancedDDIMScheduler():
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[timestep]
timestep_id = self.timesteps.index(timestep)
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = self.timesteps[timestep_id + 1]
timestep_prev = int(self.timesteps[timestep_id + 1])
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
def return_to_timestep(self, timestep, sample, sample_stablized):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
return noise_pred
def add_noise(self, original_samples, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target
if self.prediction_type == "epsilon":
return noise
else:
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target

View File

@@ -20,6 +20,8 @@ class FlowMatchScheduler():
def step(self, model_output, timestep, sample, to_final=False):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
if to_final or timestep_id + 1 >= len(self.timesteps):
@@ -36,6 +38,8 @@ class FlowMatchScheduler():
def add_noise(self, original_samples, noise, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
sigma = self.sigmas[timestep_id]
sample = (1 - sigma) * original_samples + sigma * noise

View File

@@ -0,0 +1,253 @@
import lightning as pl
from peft import LoraConfig, inject_adapter_in_model
import torch, os
from ..data.simple_text_image import TextImageDataset
from modelscope.hub.api import HubApi
class LightningModelForT2ILoRA(pl.LightningModule):
def __init__(
self,
learning_rate=1e-4,
use_gradient_checkpointing=True,
):
super().__init__()
# Set parameters
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
def load_models(self):
# This function is implemented in other modules
self.pipe = None
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"):
# Add LoRA to UNet
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
def training_step(self, batch, batch_idx):
# Data
text, image = batch["text"], batch["image"]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred, training_target)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.denoising_model().state_dict()
for name, param in state_dict.items():
if name in trainable_param_names:
checkpoint[name] = param
def add_general_parsers(parser):
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Image width.",
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="Whether to randomly flip images horizontally",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--precision",
type=str,
default="16-mixed",
choices=["32", "16", "16-mixed"],
help="Training precision",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--modelscope_model_id",
type=str,
default=None,
help="Model ID on ModelScope (https://www.modelscope.cn/). The model will be uploaded to ModelScope automatically if you provide a Model ID.",
)
parser.add_argument(
"--modelscope_access_token",
type=str,
default=None,
help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.",
)
return parser
def launch_training_task(model, args):
# dataset and data loader
dataset = TextImageDataset(
args.dataset_path,
steps_per_epoch=args.steps_per_epoch * args.batch_size,
height=args.height,
width=args.width,
center_crop=args.center_crop,
random_flip=args.random_flip
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision=args.precision,
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
)
trainer.fit(model=model, train_dataloaders=train_loader)
# Upload models
if args.modelscope_model_id is not None and args.modelscope_access_token is not None:
print(f"Uploading models to modelscope. model_id: {args.modelscope_model_id} local_path: {trainer.log_dir}")
with open(os.path.join(trainer.log_dir, "configuration.json"), "w", encoding="utf-8") as f:
f.write('{"framework":"Pytorch","task":"text-to-image-synthesis"}\n')
api = HubApi()
api.login(args.modelscope_access_token)
api.push_model(model_id=args.modelscope_model_id, model_dir=trainer.log_dir)