ipadapter for sdxl

This commit is contained in:
Artiprocher
2024-05-14 23:24:24 +08:00
parent 3b5bbb5773
commit 83461d400c
8 changed files with 251 additions and 27 deletions

View File

@@ -22,6 +22,8 @@ from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -74,6 +76,13 @@ class ModelManager:
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 254
def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict
def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"image_encoder": SVDImageEncoder,
@@ -198,6 +207,22 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl"
model = SDXLIpAdapter()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder"
model = IpAdapterCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -247,6 +272,10 @@ class ModelManager:
self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
@@ -299,7 +328,9 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict