mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
ipadapter for sdxl
This commit is contained in:
@@ -22,6 +22,8 @@ from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -74,6 +76,13 @@ class ModelManager:
|
||||
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
|
||||
return param_name in state_dict and len(state_dict) == 254
|
||||
|
||||
def is_ipadapter_xl(self, state_dict):
    """Detect an SDXL IP-Adapter checkpoint by its top-level keys.

    Args:
        state_dict: Mapping loaded from a checkpoint file.

    Returns:
        bool: True only when both "image_proj" and "ip_adapter" are
        keys of ``state_dict``.
    """
    return "image_proj" in state_dict and "ip_adapter" in state_dict
|
||||
|
||||
def is_ipadapter_xl_image_encoder(self, state_dict):
    """Detect the image encoder checkpoint used with the SDXL IP-Adapter.

    The check is a single key lookup for a weight belonging to vision
    encoder layer 47.
    NOTE(review): presumably only the intended (deep) CLIP vision model
    has a layer with this index -- confirm against other supported
    encoders.

    Args:
        state_dict: Mapping loaded from a checkpoint file.

    Returns:
        bool: True when the marker parameter name is a key of
        ``state_dict``.
    """
    param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
    return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -198,6 +207,22 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl(self, state_dict, file_path=""):
    """Build an ``SDXLIpAdapter`` from ``state_dict`` and register it.

    The raw checkpoint is first translated by the model's own
    ``state_dict_converter().from_civitai`` and then loaded. The model is
    cast to ``self.torch_dtype``, moved to ``self.device``, and stored
    under the component name "ipadapter_xl".

    Args:
        state_dict: Checkpoint contents as loaded from ``file_path``.
        file_path: Origin of the checkpoint, recorded in
            ``self.model_path``.
    """
    component = "ipadapter_xl"
    model = SDXLIpAdapter()
    converted = model.state_dict_converter().from_civitai(state_dict)
    model.load_state_dict(converted)
    model.to(self.torch_dtype).to(self.device)
    self.model[component] = model
    self.model_path[component] = file_path
|
||||
|
||||
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
    """Build an ``IpAdapterCLIPImageEmbedder`` from ``state_dict`` and register it.

    The raw checkpoint is first translated by the model's own
    ``state_dict_converter().from_diffusers`` and then loaded. The model
    is cast to ``self.torch_dtype``, moved to ``self.device``, and stored
    under the component name "ipadapter_xl_image_encoder".

    Args:
        state_dict: Checkpoint contents as loaded from ``file_path``.
        file_path: Origin of the checkpoint, recorded in
            ``self.model_path``.
    """
    component = "ipadapter_xl_image_encoder"
    model = IpAdapterCLIPImageEmbedder()
    converted = model.state_dict_converter().from_diffusers(state_dict)
    model.load_state_dict(converted)
    model.to(self.torch_dtype).to(self.device)
    self.model[component] = model
    self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -247,6 +272,10 @@ class ModelManager:
|
||||
self.load_RIFE(state_dict, file_path=file_path)
|
||||
elif self.is_translator(state_dict):
|
||||
self.load_translator(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl(state_dict):
|
||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
@@ -299,7 +328,9 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a ``torch.save``-style checkpoint onto the CPU.

    Args:
        file_path: Path of the checkpoint file.
        torch_dtype: When given, every top-level value that is a
            ``torch.Tensor`` is cast to this dtype. Non-tensor values
            (nested dicts, metadata, ...) are left untouched.

    Returns:
        The loaded mapping, with top-level tensors cast if requested.
    """
    # SECURITY: torch.load unpickles the file; only use it on checkpoints
    # from a trusted source.
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        # Cast only real tensors: calling .to() unconditionally on every
        # value raises AttributeError for checkpoints whose top-level
        # entries are not tensors.
        for name in state_dict:
            value = state_dict[name]
            if isinstance(value, torch.Tensor):
                state_dict[name] = value.to(torch_dtype)
    return state_dict
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user