From 83461d400cd8c92dadead34842d6748e5ece8822 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Tue, 14 May 2024 23:24:24 +0800
Subject: [PATCH] ipadapter for sdxl

---
 diffsynth/models/__init__.py               |  33 +++++-
 diffsynth/models/attention.py              |  16 ++-
 diffsynth/models/sd_unet.py                |  12 +-
 diffsynth/models/sdxl_ipadapter.py         | 121 +++++++++++++++++++++
 diffsynth/models/svd_image_encoder.py      |   8 +-
 diffsynth/pipelines/dancer.py              |   4 +-
 diffsynth/pipelines/stable_diffusion_xl.py |  48 +++++---
 examples/sdxl_ipadapter.py                 |  36 ++++++
 8 files changed, 251 insertions(+), 27 deletions(-)
 create mode 100644 diffsynth/models/sdxl_ipadapter.py
 create mode 100644 examples/sdxl_ipadapter.py

diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py
index 9f90505..f2ccfae 100644
--- a/diffsynth/models/__init__.py
+++ b/diffsynth/models/__init__.py
@@ -22,6 +22,8 @@ from .svd_unet import SVDUNet
 from .svd_vae_decoder import SVDVAEDecoder
 from .svd_vae_encoder import SVDVAEEncoder
 
+from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
+
 
 class ModelManager:
     def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -74,6 +76,13 @@ class ModelManager:
         param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
         return param_name in state_dict and len(state_dict) == 254
 
+    def is_ipadapter_xl(self, state_dict):
+        return "image_proj" in state_dict and "ip_adapter" in state_dict
+
+    def is_ipadapter_xl_image_encoder(self, state_dict):
+        param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
+        return param_name in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -198,6 +207,22 @@ class ModelManager:
         self.model[component] = model
         self.model_path[component] = file_path
 
+    def load_ipadapter_xl(self, state_dict, file_path=""):
+        component = "ipadapter_xl"
+        model = SDXLIpAdapter()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
+    def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
+        component = "ipadapter_xl_image_encoder"
+        model = IpAdapterCLIPImageEmbedder()
+        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -247,6 +272,10 @@ class ModelManager:
             self.load_RIFE(state_dict, file_path=file_path)
         elif self.is_translator(state_dict):
             self.load_translator(state_dict, file_path=file_path)
+        elif self.is_ipadapter_xl(state_dict):
+            self.load_ipadapter_xl(state_dict, file_path=file_path)
+        elif self.is_ipadapter_xl_image_encoder(state_dict):
+            self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
@@ -299,7 +328,9 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
 def load_state_dict_from_bin(file_path, torch_dtype=None):
     state_dict = torch.load(file_path, map_location="cpu")
     if torch_dtype is not None:
-        state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
+        for i in state_dict:
+            if isinstance(state_dict[i], torch.Tensor):
+                state_dict[i] = state_dict[i].to(torch_dtype)
     return state_dict
 
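Note on the loader changes: the `ip-adapter_sdxl.bin` checkpoint is a nested dict with two top-level entries, which is what `is_ipadapter_xl` keys on, and also why `load_state_dict_from_bin` now guards with `isinstance` before casting (the top-level values are dicts, not tensors). A minimal sketch, assuming that checkpoint layout:

```python
import torch

# The IP-Adapter checkpoint is {"image_proj": {...}, "ip_adapter": {...}};
# a blanket .to(torch_dtype) over its top-level values would fail because
# those values are dicts rather than tensors.
state_dict = torch.load("models/IpAdapter/ip-adapter_sdxl.bin", map_location="cpu")
print(list(state_dict.keys()))  # expected: ['image_proj', 'ip_adapter']
```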
diff --git a/diffsynth/models/attention.py b/diffsynth/models/attention.py
index 5961c11..4d8c3d9 100644
--- a/diffsynth/models/attention.py
+++ b/diffsynth/models/attention.py
@@ -26,7 +26,15 @@ class Attention(torch.nn.Module):
         self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
         self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
 
-    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
+        batch_size = q.shape[0]
+        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+        hidden_states = hidden_states + scale * ip_hidden_states
+        return hidden_states
+
+    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
@@ -41,6 +49,8 @@ class Attention(torch.nn.Module):
         v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
 
         hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        if ipadapter_kwargs is not None:
+            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
 
@@ -72,5 +82,5 @@ class Attention(torch.nn.Module):
 
         return hidden_states
 
-    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
-        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)
\ No newline at end of file
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
+        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs)
\ No newline at end of file
diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py
index 6c99ae4..8a8b17e 100644
--- a/diffsynth/models/sd_unet.py
+++ b/diffsynth/models/sd_unet.py
@@ -47,15 +47,15 @@ class BasicTransformerBlock(torch.nn.Module):
         self.ff = torch.nn.Linear(dim * 4, dim)
 
 
-    def forward(self, hidden_states, encoder_hidden_states):
+    def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None):
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
-        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None,)
+        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
         hidden_states = attn_output + hidden_states
 
         # 2. Cross-Attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs)
         hidden_states = attn_output + hidden_states
 
         # 3. Feed-forward
@@ -150,6 +150,7 @@ class AttentionBlock(torch.nn.Module):
         hidden_states, time_emb, text_emb, res_stack,
         cross_frame_attention=False,
         tiled=False, tile_size=64, tile_stride=32,
+        ipadapter_kwargs_list={},
         **kwargs
     ):
         batch, _, height, width = hidden_states.shape
@@ -188,10 +189,11 @@ class AttentionBlock(torch.nn.Module):
             )
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         else:
-            for block in self.transformer_blocks:
+            for block_id, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states,
-                    encoder_hidden_states=encoder_hidden_states
+                    encoder_hidden_states=encoder_hidden_states,
+                    ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None)
                 )
         if cross_frame_attention:
             hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
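`interact_with_ipadapter` is the decoupled cross-attention used by IP-Adapter: the same query attends separately over text keys/values and over image keys/values, and the image branch is blended in with a scale factor. A minimal standalone sketch; batch size, head count, and dims below are illustrative, not taken from the patch:

```python
import torch

# Illustrative shapes: batch 2, 8 heads, 4096 query tokens, head_dim 40.
q = torch.randn(2, 8, 4096, 40)
text_k, text_v = torch.randn(2, 8, 77, 40), torch.randn(2, 8, 77, 40)
ip_k, ip_v = torch.randn(2, 8, 4, 40), torch.randn(2, 8, 4, 40)  # 4 image tokens

# Two separate attention passes over the same query ...
text_out = torch.nn.functional.scaled_dot_product_attention(q, text_k, text_v)
ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)

# ... combined exactly as in interact_with_ipadapter.
scale = 1.0
hidden_states = text_out + scale * ip_out
```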
diff --git a/diffsynth/models/sdxl_ipadapter.py b/diffsynth/models/sdxl_ipadapter.py
new file mode 100644
index 0000000..ab251e8
--- /dev/null
+++ b/diffsynth/models/sdxl_ipadapter.py
@@ -0,0 +1,121 @@
+from .svd_image_encoder import SVDImageEncoder
+from transformers import CLIPImageProcessor
+import torch
+
+
+class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
+    def __init__(self):
+        super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
+        self.image_processor = CLIPImageProcessor()
+
+    def forward(self, image):
+        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
+        return super().forward(pixel_values)
+
+
+class IpAdapterImageProjModel(torch.nn.Module):
+    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
+        super().__init__()
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
+
+class IpAdapterModule(torch.nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, hidden_states):
+        ip_k = self.to_k_ip(hidden_states)
+        ip_v = self.to_v_ip(hidden_states)
+        return ip_k, ip_v
+
+
+class SDXLIpAdapter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
+        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
+        self.image_proj = IpAdapterImageProjModel()
+        self.set_full_adapter()
+
+    def set_full_adapter(self):
+        map_list = sum([
+            [(7, i) for i in range(2)],
+            [(10, i) for i in range(2)],
+            [(15, i) for i in range(10)],
+            [(18, i) for i in range(10)],
+            [(25, i) for i in range(10)],
+            [(28, i) for i in range(10)],
+            [(31, i) for i in range(10)],
+            [(35, i) for i in range(2)],
+            [(38, i) for i in range(2)],
+            [(41, i) for i in range(2)],
+            [(21, i) for i in range(10)],
+        ], [])
+        self.call_block_id = {i: j for j, i in enumerate(map_list)}
+
+    def set_less_adapter(self):
+        map_list = sum([
+            [(7, i) for i in range(2)],
+            [(10, i) for i in range(2)],
+            [(15, i) for i in range(10)],
+            [(18, i) for i in range(10)],
+            [(25, i) for i in range(10)],
+            [(28, i) for i in range(10)],
+            [(31, i) for i in range(10)],
+            [(35, i) for i in range(2)],
+            [(38, i) for i in range(2)],
+            [(41, i) for i in range(2)],
+            [(21, i) for i in range(10)],
+        ], [])
+        self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
+
+    def forward(self, hidden_states, scale=1.0):
+        hidden_states = self.image_proj(hidden_states)
+        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+        ip_kv_dict = {}
+        for (block_id, transformer_id) in self.call_block_id:
+            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
+            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+            if block_id not in ip_kv_dict:
+                ip_kv_dict[block_id] = {}
+            ip_kv_dict[block_id][transformer_id] = {
+                "ip_k": ip_k,
+                "ip_v": ip_v,
+                "scale": scale
+            }
+        return ip_kv_dict
+
+    def state_dict_converter(self):
+        return SDXLIpAdapterStateDictConverter()
+
+
+class SDXLIpAdapterStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        state_dict_ = {}
+        for name in state_dict["ip_adapter"]:
+            names = name.split(".")
+            layer_id = str(int(names[0]) // 2)
+            name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
+            state_dict_[name_] = state_dict["ip_adapter"][name]
+        for name in state_dict["image_proj"]:
+            name_ = "image_proj." + name
+            state_dict_[name_] = state_dict["image_proj"][name]
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
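The magic numbers in `set_less_adapter` (`j>=34 and j<44`) select the ten `IpAdapterModule`s that belong to UNet block 28. A standalone sketch that re-derives the mapping from the code above:

```python
# Re-derive the (block_id, transformer_id) -> module index mapping.
map_list = sum([
    [(7, i) for i in range(2)],
    [(10, i) for i in range(2)],
    [(15, i) for i in range(10)],
    [(18, i) for i in range(10)],
    [(25, i) for i in range(10)],
    [(28, i) for i in range(10)],
    [(31, i) for i in range(10)],
    [(35, i) for i in range(2)],
    [(38, i) for i in range(2)],
    [(41, i) for i in range(2)],
    [(21, i) for i in range(10)],
], [])
less = {key: j for j, key in enumerate(map_list) if 34 <= j < 44}
print(sorted(less))  # [(28, 0), ..., (28, 9)]: only UNet block 28 keeps the adapter
```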
diff --git a/diffsynth/models/svd_image_encoder.py b/diffsynth/models/svd_image_encoder.py
index 416aebb..c3aa32f 100644
--- a/diffsynth/models/svd_image_encoder.py
+++ b/diffsynth/models/svd_image_encoder.py
@@ -25,11 +25,13 @@ class CLIPVisionEmbeddings(torch.nn.Module):
 
 
 class SVDImageEncoder(torch.nn.Module):
-    def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024):
+    def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
         super().__init__()
         self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
         self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
-        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=16, head_dim=80, use_quick_gelu=False) for _ in range(num_encoder_layers)])
+        self.encoders = torch.nn.ModuleList([
+            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
+            for _ in range(num_encoder_layers)])
         self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
         self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
 
@@ -78,7 +80,7 @@ class SVDImageEncoderStateDictConverter:
             if name == "vision_model.embeddings.class_embedding":
                 param = state_dict[name].view(1, 1, -1)
             elif name == "vision_model.embeddings.position_embedding.weight":
-                param = state_dict[name].view(1, 257, 1280)
+                param = state_dict[name].unsqueeze(0)
             state_dict_[rename_dict[name]] = param
         elif name.startswith("vision_model.encoder.layers."):
             param = state_dict[name]
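The `position_embedding` change matters because the IP-Adapter image encoder built on this class (48 layers, width 1664) no longer fits the hard-coded `(1, 257, 1280)` shape; `unsqueeze(0)` produces the same result for the original SVD encoder while generalizing to other widths. A sketch assuming the usual CLIP shapes (a 224px, patch-14 ViT has 16*16 + 1 = 257 positions):

```python
import torch

svd_pos = torch.randn(257, 1280)   # SVD's ViT-H-like encoder, width 1280
big_pos = torch.randn(257, 1664)   # IP-Adapter's ViT-bigG-like encoder, width 1664

# For the original encoder both forms are identical:
assert torch.equal(svd_pos.view(1, 257, 1280), svd_pos.unsqueeze(0))

big_pos.unsqueeze(0)               # works for any width
# big_pos.view(1, 257, 1280)       # would raise: 257*1664 elements don't fit
```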
diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py
index d19e746..6143d11 100644
--- a/diffsynth/pipelines/dancer.py
+++ b/diffsynth/pipelines/dancer.py
@@ -119,6 +119,7 @@ def lets_dance_xl(
     add_text_embeds = None,
     timestep = None,
     encoder_hidden_states = None,
+    ipadapter_kwargs_list = {},
     controlnet_frames = None,
     unet_batch_size = 1,
     controlnet_batch_size = 1,
@@ -151,7 +152,8 @@ def lets_dance_xl(
     for block_id, block in enumerate(unet.blocks):
         hidden_states, time_emb, text_emb, res_stack = block(
             hidden_states, time_emb, text_emb, res_stack,
-            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
         )
         # 4.2 AnimateDiff
         if motion_modules is not None:
diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py
index 246a361..ed5ed62 100644
--- a/diffsynth/pipelines/stable_diffusion_xl.py
+++ b/diffsynth/pipelines/stable_diffusion_xl.py
@@ -1,7 +1,8 @@
-from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
+from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterCLIPImageEmbedder
 # TODO: SDXL ControlNet
 from ..prompts import SDXLPrompter
 from ..schedulers import EnhancedDDIMScheduler
+from .dancer import lets_dance_xl
 import torch
 from tqdm import tqdm
 from PIL import Image
@@ -22,6 +23,8 @@ class SDXLImagePipeline(torch.nn.Module):
         self.unet: SDXLUNet = None
         self.vae_decoder: SDXLVAEDecoder = None
         self.vae_encoder: SDXLVAEEncoder = None
+        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+        self.ipadapter: SDXLIpAdapter = None
         # TODO: SDXL ControlNet
 
     def fetch_main_models(self, model_manager: ModelManager):
@@ -35,6 +38,13 @@ class SDXLImagePipeline(torch.nn.Module):
     def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
         # TODO: SDXL ControlNet
         pass
+
+
+    def fetch_ipadapter(self, model_manager: ModelManager):
+        if "ipadapter_xl" in model_manager.model:
+            self.ipadapter = model_manager.ipadapter_xl
+        if "ipadapter_xl_image_encoder" in model_manager.model:
+            self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
 
 
     def fetch_prompter(self, model_manager: ModelManager):
@@ -50,6 +60,7 @@ class SDXLImagePipeline(torch.nn.Module):
         pipe.fetch_main_models(model_manager)
         pipe.fetch_prompter(model_manager)
         pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
+        pipe.fetch_ipadapter(model_manager)
         return pipe
 
 
@@ -74,6 +85,7 @@ class SDXLImagePipeline(torch.nn.Module):
         clip_skip=1,
         clip_skip_2=2,
         input_image=None,
+        ipadapter_images=None,
         controlnet_image=None,
         denoising_strength=1.0,
         height=1024,
@@ -118,30 +130,38 @@ class SDXLImagePipeline(torch.nn.Module):
         # Prepare positional id
         add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
+
+        # IP-Adapter
+        if ipadapter_images is not None:
+            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding)
+            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
+        else:
+            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
 
         # Denoise
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = torch.IntTensor((timestep,))[0].to(self.device)
 
             # Classifier-free guidance
+            noise_pred_posi = lets_dance_xl(
+                self.unet,
+                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
+                add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
+                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+                ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
+            )
             if cfg_scale != 1.0:
-                noise_pred_posi = self.unet(
-                    latents, timestep, prompt_emb_posi,
-                    add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
-                )
-                noise_pred_nega = self.unet(
-                    latents, timestep, prompt_emb_nega,
+                noise_pred_nega = lets_dance_xl(
+                    self.unet,
+                    sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
                     add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
-                noise_pred = self.unet(
-                    latents, timestep, prompt_emb_posi,
-                    add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
-                )
+                noise_pred = noise_pred_posi
 
             latents = self.scheduler.step(noise_pred, timestep, latents)
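Taken together, the plumbing works like this: `SDXLIpAdapter.forward` returns a nested dict keyed by UNet block id and then by transformer block id, `lets_dance_xl` slices it per UNet block, `AttentionBlock` slices it per transformer block, and `Attention.torch_forward` applies it only when it is not `None`. A condensed sketch of that data flow; the tensor values are placeholders:

```python
# What SDXLIpAdapter.forward() produces (shape of the structure, not real tensors):
ipadapter_kwargs_list = {28: {0: {"ip_k": ..., "ip_v": ..., "scale": 1.0}}}

per_unet_block = ipadapter_kwargs_list.get(28, {})   # done in lets_dance_xl
per_transformer = per_unet_block.get(0, None)        # done in AttentionBlock
# BasicTransformerBlock hands per_transformer to attn2; torch_forward then calls
# interact_with_ipadapter(hidden_states, q, **per_transformer) when it is not
# None, so blocks without an entry run plain cross-attention.
```

Note that the negative CFG branch is paired with `torch.zeros_like(ipadapter_image_encoding)`, so guidance pushes toward the reference image as well as the prompt.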
diff --git a/examples/sdxl_ipadapter.py b/examples/sdxl_ipadapter.py
new file mode 100644
index 0000000..706bef0
--- /dev/null
+++ b/examples/sdxl_ipadapter.py
@@ -0,0 +1,36 @@
+from diffsynth import ModelManager, SDXLImagePipeline
+import torch
+
+
+# Download models
+# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
+# `models/IpAdapter/image_encoder/model.safetensors`: [link](https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors)
+# `models/IpAdapter/ip-adapter_sdxl.bin`: [link](https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.bin)
+
+# Load models
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
+    "models/IpAdapter/image_encoder/model.safetensors",
+    "models/IpAdapter/ip-adapter_sdxl.bin"
+])
+pipe = SDXLImagePipeline.from_model_manager(model_manager)
+pipe.ipadapter.set_less_adapter()
+
+torch.manual_seed(0)
+style_image = pipe(
+    prompt="Starry Night, blue sky, by van Gogh",
+    negative_prompt="dark, gray",
+    cfg_scale=5,
+    height=1024, width=1024, num_inference_steps=30,
+)
+style_image.save("style_image.jpg")
+
+image = pipe(
+    prompt="a cat",
+    negative_prompt="",
+    cfg_scale=5,
+    height=1024, width=1024, num_inference_steps=30,
+    ipadapter_images=[style_image]
+)
+image.save("transferred_image.jpg")
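A possible follow-up, not part of this patch: the example calls `set_less_adapter()`, which keeps the adapter only in UNet block 28 and tends toward style transfer; `set_full_adapter()` applies all 70 modules, so the reference image influences the result more strongly. A hedged variant reusing the objects above:

```python
# Variant: enable IP-Adapter in every mapped attention layer for stronger
# image guidance (closer to classic IP-Adapter image prompting).
pipe.ipadapter.set_full_adapter()
image = pipe(
    prompt="a cat",
    negative_prompt="",
    cfg_scale=5,
    height=1024, width=1024, num_inference_steps=30,
    ipadapter_images=[style_image]
)
image.save("transferred_image_full_adapter.jpg")
```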