ipadapter

This commit is contained in:
Artiprocher
2024-06-09 15:26:44 +08:00
parent 84744127f6
commit fe3870fa14
7 changed files with 118 additions and 9 deletions

View File

@@ -22,7 +22,8 @@ from .svd_unet import SVDUNet
from .svd_vae_decoder import SVDVAEDecoder from .svd_vae_decoder import SVDVAEDecoder
from .svd_vae_encoder import SVDVAEEncoder from .svd_vae_encoder import SVDVAEEncoder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT from .hunyuan_dit import HunyuanDiT
@@ -79,12 +80,19 @@ class ModelManager:
param_name = "model.encoder.layers.5.self_attn_layer_norm.weight" param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
return param_name in state_dict and len(state_dict) == 254 return param_name in state_dict and len(state_dict) == 254
def is_ipadapter(self, state_dict):
    """Detect an SD 1.5 IP-Adapter checkpoint.

    Identified by the presence of the "image_proj" and "ip_adapter"
    sections and the SD-1.5-specific projection weight shape.
    """
    if "image_proj" not in state_dict or "ip_adapter" not in state_dict:
        return False
    return state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
def is_ipadapter_image_encoder(self, state_dict):
    """Detect the CLIP image encoder checkpoint used by the SD IP-Adapter.

    Matches on a marker parameter name plus the exact parameter count.
    """
    marker = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
    return len(state_dict) == 521 and marker in state_dict
def is_ipadapter_xl(self, state_dict): def is_ipadapter_xl(self, state_dict):
return "image_proj" in state_dict and "ip_adapter" in state_dict return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
def is_ipadapter_xl_image_encoder(self, state_dict): def is_ipadapter_xl_image_encoder(self, state_dict):
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight" param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict return param_name in state_dict and len(state_dict) == 777
def is_hunyuan_dit_clip_text_encoder(self, state_dict): def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight" param_name = "bert.encoder.layer.23.attention.output.dense.weight"
@@ -226,6 +234,22 @@ class ModelManager:
self.model[component] = model self.model[component] = model
self.model_path[component] = file_path self.model_path[component] = file_path
def load_ipadapter(self, state_dict, file_path=""):
    """Instantiate an SDIpAdapter from *state_dict* and register it.

    The raw (civitai-format) weights are converted to this project's
    layout before loading; the model is registered under "ipadapter".
    """
    component = "ipadapter"
    adapter = SDIpAdapter()
    converted = adapter.state_dict_converter().from_civitai(state_dict)
    adapter.load_state_dict(converted)
    adapter.to(self.torch_dtype).to(self.device)
    self.model[component] = adapter
    self.model_path[component] = file_path
def load_ipadapter_image_encoder(self, state_dict, file_path=""):
    """Instantiate the IP-Adapter CLIP image encoder and register it.

    Diffusers-format weights are converted before loading; the model is
    registered under "ipadapter_image_encoder".
    """
    component = "ipadapter_image_encoder"
    encoder = IpAdapterCLIPImageEmbedder()
    converted = encoder.state_dict_converter().from_diffusers(state_dict)
    encoder.load_state_dict(converted)
    encoder.to(self.torch_dtype).to(self.device)
    self.model[component] = encoder
    self.model_path[component] = file_path
def load_ipadapter_xl(self, state_dict, file_path=""): def load_ipadapter_xl(self, state_dict, file_path=""):
component = "ipadapter_xl" component = "ipadapter_xl"
model = SDXLIpAdapter() model = SDXLIpAdapter()
@@ -236,7 +260,7 @@ class ModelManager:
def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""): def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
component = "ipadapter_xl_image_encoder" component = "ipadapter_xl_image_encoder"
model = IpAdapterCLIPImageEmbedder() model = IpAdapterXLCLIPImageEmbedder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device) model.to(self.torch_dtype).to(self.device)
self.model[component] = model self.model[component] = model
@@ -330,6 +354,10 @@ class ModelManager:
self.load_RIFE(state_dict, file_path=file_path) self.load_RIFE(state_dict, file_path=file_path)
elif self.is_translator(state_dict): elif self.is_translator(state_dict):
self.load_translator(state_dict, file_path=file_path) self.load_translator(state_dict, file_path=file_path)
elif self.is_ipadapter(state_dict):
self.load_ipadapter(state_dict, file_path=file_path)
elif self.is_ipadapter_image_encoder(state_dict):
self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
elif self.is_ipadapter_xl(state_dict): elif self.is_ipadapter_xl(state_dict):
self.load_ipadapter_xl(state_dict, file_path=file_path) self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict): elif self.is_ipadapter_xl_image_encoder(state_dict):

View File

@@ -0,0 +1,56 @@
from .svd_image_encoder import SVDImageEncoder
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
from transformers import CLIPImageProcessor
import torch
class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
    """CLIP vision encoder with built-in image preprocessing for the SD IP-Adapter."""

    def __init__(self):
        super().__init__()
        # Converts raw images (PIL / arrays) into normalized pixel tensors.
        self.image_processor = CLIPImageProcessor()

    def forward(self, image):
        """Preprocess *image* and encode it with the underlying CLIP vision model."""
        processed = self.image_processor(images=image, return_tensors="pt")
        # Borrow device/dtype from an existing parameter so inputs always match the model.
        reference = self.embeddings.class_embedding
        pixel_values = processed.pixel_values.to(device=reference.device, dtype=reference.dtype)
        return super().forward(pixel_values)
class SDIpAdapter(torch.nn.Module):
    """IP-Adapter for Stable Diffusion v1.5.

    Projects a CLIP image embedding into per-block key/value tensors that the
    UNet's cross-attention layers mix in alongside the text conditioning.
    """
    def __init__(self):
        super().__init__()
        # (cross_attention_dim, hidden_dim) pairs, one per adapter module.
        # Order must line up with the block ids in set_full_adapter().
        shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
        self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
        self.set_full_adapter()

    def set_full_adapter(self):
        """Map every (block_id, transformer_id) pair to its adapter module index."""
        block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
        self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}

    def set_less_adapter(self):
        # IP-Adapter for SD v1.5 doesn't support this feature; fall back to the
        # full adapter. (Fixed: the original called self.set_full_adapter(self),
        # passing `self` twice and raising TypeError.)
        self.set_full_adapter()

    def forward(self, hidden_states, scale=1.0):
        """Produce per-block IP-Adapter key/value tensors from an image embedding.

        Args:
            hidden_states: CLIP image embedding fed through image_proj.
            scale: strength of the IP-Adapter contribution in attention.

        Returns:
            Nested dict: {block_id: {transformer_id: {"ip_k", "ip_v", "scale"}}}.
        """
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for (block_id, transformer_id), ipadapter_id in self.call_block_id.items():
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            ip_kv_dict.setdefault(block_id, {})[transformer_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    def state_dict_converter(self):
        """Return the converter that maps external checkpoint keys onto this module."""
        return SDIpAdapterStateDictConverter()
class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
    # Reuses the SDXL converter's key-mapping methods unchanged for the SD 1.5
    # IP-Adapter; presumably the checkpoint key layout is identical — TODO confirm.
    def __init__(self):
        # Intentionally does not call super().__init__(); the converter is stateless.
        pass

View File

@@ -3,7 +3,7 @@ from transformers import CLIPImageProcessor
import torch import torch
class IpAdapterCLIPImageEmbedder(SVDImageEncoder): class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
def __init__(self): def __init__(self):
super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104) super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
self.image_processor = CLIPImageProcessor() self.image_processor = CLIPImageProcessor()

View File

@@ -11,6 +11,7 @@ def lets_dance(
sample = None, sample = None,
timestep = None, timestep = None,
encoder_hidden_states = None, encoder_hidden_states = None,
ipadapter_kwargs_list = {},
controlnet_frames = None, controlnet_frames = None,
unet_batch_size = 1, unet_batch_size = 1,
controlnet_batch_size = 1, controlnet_batch_size = 1,
@@ -80,6 +81,7 @@ def lets_dance(
text_emb[batch_id: batch_id_], text_emb[batch_id: batch_id_],
res_stack, res_stack,
cross_frame_attention=cross_frame_attention, cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
) )
hidden_states_output.append(hidden_states) hidden_states_output.append(hidden_states)

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder, SDIpAdapter, IpAdapterCLIPImageEmbedder
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler from ..schedulers import EnhancedDDIMScheduler
@@ -24,6 +24,8 @@ class SDImagePipeline(torch.nn.Module):
self.vae_decoder: SDVAEDecoder = None self.vae_decoder: SDVAEDecoder = None
self.vae_encoder: SDVAEEncoder = None self.vae_encoder: SDVAEEncoder = None
self.controlnet: MultiControlNetManager = None self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
self.ipadapter: SDIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager): def fetch_main_models(self, model_manager: ModelManager):
@@ -44,6 +46,13 @@ class SDImagePipeline(torch.nn.Module):
controlnet_units.append(controlnet_unit) controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units) self.controlnet = MultiControlNetManager(controlnet_units)
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter" in model_manager.model:
self.ipadapter = model_manager.ipadapter
if "ipadapter_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_image_encoder
def fetch_prompter(self, model_manager: ModelManager): def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager) self.prompter.load_from_model_manager(model_manager)
@@ -58,6 +67,7 @@ class SDImagePipeline(torch.nn.Module):
pipe.fetch_main_models(model_manager) pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager) pipe.fetch_prompter(model_manager)
pipe.fetch_controlnet_models(model_manager, controlnet_config_units) pipe.fetch_controlnet_models(model_manager, controlnet_config_units)
pipe.fetch_ipadapter(model_manager)
return pipe return pipe
@@ -81,6 +91,8 @@ class SDImagePipeline(torch.nn.Module):
cfg_scale=7.5, cfg_scale=7.5,
clip_skip=1, clip_skip=1,
input_image=None, input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None, controlnet_image=None,
denoising_strength=1.0, denoising_strength=1.0,
height=512, height=512,
@@ -108,6 +120,14 @@ class SDImagePipeline(torch.nn.Module):
prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True) prompt_emb_posi = self.prompter.encode_prompt(self.text_encoder, prompt, clip_skip=clip_skip, device=self.device, positive=True)
prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False) prompt_emb_nega = self.prompter.encode_prompt(self.text_encoder, negative_prompt, clip_skip=clip_skip, device=self.device, positive=False)
# IP-Adapter
if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Prepare ControlNets # Prepare ControlNets
if controlnet_image is not None: if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype) controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
@@ -122,12 +142,14 @@ class SDImagePipeline(torch.nn.Module):
self.unet, motion_modules=None, controlnet=self.controlnet, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image, sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
device=self.device, vram_limit_level=0 device=self.device, vram_limit_level=0
) )
noise_pred_nega = lets_dance( noise_pred_nega = lets_dance(
self.unet, motion_modules=None, controlnet=self.controlnet, self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image, sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_image,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
device=self.device, vram_limit_level=0 device=self.device, vram_limit_level=0
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterCLIPImageEmbedder from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
# TODO: SDXL ControlNet # TODO: SDXL ControlNet
from ..prompts import SDXLPrompter from ..prompts import SDXLPrompter
from ..schedulers import EnhancedDDIMScheduler from ..schedulers import EnhancedDDIMScheduler
@@ -23,7 +23,7 @@ class SDXLImagePipeline(torch.nn.Module):
self.unet: SDXLUNet = None self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None self.ipadapter: SDXLIpAdapter = None
# TODO: SDXL ControlNet # TODO: SDXL ControlNet
@@ -86,6 +86,7 @@ class SDXLImagePipeline(torch.nn.Module):
clip_skip_2=2, clip_skip_2=2,
input_image=None, input_image=None,
ipadapter_images=None, ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None, controlnet_image=None,
denoising_strength=1.0, denoising_strength=1.0,
height=1024, height=1024,
@@ -134,7 +135,7 @@ class SDXLImagePipeline(torch.nn.Module):
# IP-Adapter # IP-Adapter
if ipadapter_images is not None: if ipadapter_images is not None:
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding) ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else: else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}