From 4f40683fd8d8d8c8bc7d6da7dc5c03fbba17fa6d Mon Sep 17 00:00:00 2001 From: root <1576993271@qq.com> Date: Tue, 26 Nov 2024 18:08:50 +0800 Subject: [PATCH 1/2] support flux ipadapter --- diffsynth/configs/model_config.py | 26 +++++++++++++ diffsynth/models/flux_dit.py | 23 ++++++++---- diffsynth/models/model_manager.py | 1 + diffsynth/models/sd3_dit.py | 11 ++++-- diffsynth/pipelines/flux_image.py | 53 ++++++++++++++++++++++----- examples/Ip-Adapter/flux_ipadapter.py | 38 +++++++++++++++++++ 6 files changed, 133 insertions(+), 19 deletions(-) create mode 100644 examples/Ip-Adapter/flux_ipadapter.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index b5057d6..b133c5e 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -36,6 +36,7 @@ from ..models.flux_dit import FluxDiT from ..models.flux_text_encoder import FluxTextEncoder2 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_controlnet import FluxControlNet +from ..models.flux_ipadapter import FluxIpAdapter from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_dit import CogDiT @@ -88,6 +89,7 @@ model_loader_configs = [ (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), + (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"), (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"), (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), @@ -102,6 +104,7 @@ huggingface_model_loader_configs = [ ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), + ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel") ] patch_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -257,6 +260,17 @@ preset_models_on_huggingface = { ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), ], + "InstantX/FLUX.1-dev-IP-Adapter": { + "file_list": [ + ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"), + ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"), + ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"), + ], + "load_path": [ + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin", + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder", + ], + }, # RIFE "RIFE": [ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"), @@ -541,6 +555,17 @@ preset_models_on_modelscope = { "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"), ], + "InstantX/FLUX.1-dev-IP-Adapter": { + "file_list": [ + ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"), + ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"), + ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"), + ], + "load_path": [ + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin", + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder", + ], + }, # ESRGAN "ESRGAN_x4": [ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), @@ -642,6 +667,7 @@ Preset_model_id: TypeAlias = Literal[ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", + "InstantX/FLUX.1-dev-IP-Adapter", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", "OmostPrompt", diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 4116d3c..faf58cd 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -4,6 +4,12 @@ from einops import rearrange from .tiler import TileWorker from .utils import init_weights_on_device +def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size, num_tokens = hidden_states.shape[0:2] + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states class RoPEEmbedding(torch.nn.Module): @@ -64,8 +70,7 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb): + def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None): batch_size = hidden_states_a.shape[0] # Part A @@ -90,6 +95,8 @@ class FluxJointAttention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] + if ipadapter_kwargs_list is not None: + hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list) hidden_states_a = self.a_to_out(hidden_states_a) if self.only_out_a: return hidden_states_a @@ -122,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module): ) - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -219,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def process_attention(self, hidden_states, image_rotary_emb): + def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -231,16 +238,18 @@ class FluxSingleTransformerBlock(torch.nn.Module): hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) + if ipadapter_kwargs_list is not None: + hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list) return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): residual = hidden_states_a norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) hidden_states_a = self.to_qkv_mlp(norm_hidden_states) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb) + attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 50b4a92..f8351d2 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -39,6 +39,7 @@ from .hunyuan_dit import HunyuanDiT from .flux_dit import FluxDiT from .flux_text_encoder import FluxTextEncoder2 from .flux_vae import FluxVAEEncoder, FluxVAEDecoder +from .flux_ipadapter import FluxIpAdapter from .cog_vae import CogVAEEncoder, CogVAEDecoder from .cog_dit import CogDiT diff --git a/diffsynth/models/sd3_dit.py b/diffsynth/models/sd3_dit.py index 6168088..730e6fc 100644 --- a/diffsynth/models/sd3_dit.py +++ b/diffsynth/models/sd3_dit.py @@ -6,16 +6,21 @@ from .tiler import TileWorker class RMSNorm(torch.nn.Module): - def __init__(self, dim, eps): + def __init__(self, dim, eps, elementwise_affine=True): super().__init__() - self.weight = torch.nn.Parameter(torch.ones((dim,))) self.eps = eps + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.ones((dim,))) + else: + self.weight = None def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - hidden_states = hidden_states.to(input_dtype) * self.weight + hidden_states = hidden_states.to(input_dtype) + if self.weight is not None: + hidden_states = hidden_states * self.weight return hidden_states diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 69b7a16..42d142c 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -1,4 +1,4 @@ -from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder +from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompters import FluxPrompter from ..schedulers import FlowMatchScheduler @@ -9,7 +9,7 @@ from tqdm import tqdm import numpy as np from PIL import Image from ..models.tiler import FastTileWorker - +from transformers import SiglipVisionModel class FluxImagePipeline(BasePipeline): @@ -25,7 +25,9 @@ class FluxImagePipeline(BasePipeline): self.vae_decoder: FluxVAEDecoder = None self.vae_encoder: FluxVAEEncoder = None self.controlnet: FluxMultiControlNetManager = None - self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet'] + self.ipadapter: FluxIpAdapter = None + self.ipadapter_image_encoder: SiglipVisionModel = None + self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder'] def denoising_model(self): @@ -53,6 +55,9 @@ class FluxImagePipeline(BasePipeline): controlnet_units.append(controlnet_unit) self.controlnet = FluxMultiControlNetManager(controlnet_units) + # IP-Adapters + self.ipadapter = model_manager.fetch_model("flux_ipadapter") + self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") @staticmethod def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): @@ -129,18 +134,24 @@ class FluxImagePipeline(BasePipeline): controlnet_frames.append(image) return controlnet_frames + def prepare_ipadapter_inputs(self, images, height=384, width=384): + images = [image.convert("RGB").resize((width, height), resample=3) for image in images] + images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] + return torch.cat(images, dim=0) @torch.no_grad() def __call__( self, prompt, local_prompts=None, - masks=None, + masks=None, mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=3.5, input_image=None, + ipadapter_images=None, + ipadapter_scale=1.0, controlnet_image=None, controlnet_inpaint_mask=None, enable_controlnet_on_negative=False, @@ -157,7 +168,7 @@ class FluxImagePipeline(BasePipeline): progress_bar_st=None, ): height, width = self.check_resize_height_width(height, width) - + # Tiler parameters tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} @@ -187,6 +198,17 @@ class FluxImagePipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) + # IP-Adapter + if ipadapter_images is not None: + self.load_models_to_device(['ipadapter_image_encoder']) + ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output + self.load_models_to_device(['ipadapter']) + ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} + ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + # Prepare ControlNets if controlnet_image is not None: self.load_models_to_device(['vae_encoder']) @@ -208,7 +230,7 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -219,7 +241,7 @@ class FluxImagePipeline(BasePipeline): noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -256,6 +278,7 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, + ipadapter_kwargs_list={}, **kwargs ): if tiled: @@ -319,15 +342,27 @@ def lets_dance_flux( # Joint Blocks for block_id, block in enumerate(dit.blocks): - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states = hidden_states + controlnet_res_stack[block_id] # Single Blocks hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + num_joint_blocks = len(dit.blocks) for block_id, block in enumerate(dit.single_blocks): - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block( + hidden_states, + prompt_emb, + conditioning, + image_rotary_emb, + ipadapter_kwargs_list=ipadapter_kwargs_list.get( + block_id + num_joint_blocks, None)) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] diff --git a/examples/Ip-Adapter/flux_ipadapter.py b/examples/Ip-Adapter/flux_ipadapter.py new file mode 100644 index 0000000..09639e8 --- /dev/null +++ b/examples/Ip-Adapter/flux_ipadapter.py @@ -0,0 +1,38 @@ +from diffsynth import ModelManager, download_models, FluxImagePipeline +import torch + +# Download models (automatically) +# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin`: [link](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/blob/main/ip-adapter.bin) +# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder`: [link](https://huggingface.co/google/siglip-so400m-patch14-384) +download_models(["InstantX/FLUX.1-dev-IP-Adapter"]) + +# Load models +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin", + "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder", + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors", +]) +seed = 42 +pipe = FluxImagePipeline.from_model_manager(model_manager) +torch.manual_seed(seed) +origin_prompt = "a rabbit in a garden, colorful flowers" +image = pipe( + prompt=origin_prompt, + cfg_scale=1.0, embedded_guidance=3.5, + height=1280, width=960, num_inference_steps=30 +) +image.save("style image.jpg") + +torch.manual_seed(seed) +image = pipe( + prompt="A piggy", + cfg_scale=1.0, embedded_guidance=3.5, + height=1280, width=960, num_inference_steps=30, + ipadapter_images=[image], ipadapter_scale=0.7 +) +image.save("A piggy.jpg") + From f2130c4c25ff2074e32bb5cb5cc05955f2377427 Mon Sep 17 00:00:00 2001 From: root <1576993271@qq.com> Date: Tue, 26 Nov 2024 19:08:41 +0800 Subject: [PATCH 2/2] minor --- diffsynth/models/flux_ipadapter.py | 94 ++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 diffsynth/models/flux_ipadapter.py diff --git a/diffsynth/models/flux_ipadapter.py b/diffsynth/models/flux_ipadapter.py new file mode 100644 index 0000000..575c752 --- /dev/null +++ b/diffsynth/models/flux_ipadapter.py @@ -0,0 +1,94 @@ +from .svd_image_encoder import SVDImageEncoder +from .sd3_dit import RMSNorm +from transformers import CLIPImageProcessor +import torch + + +class MLPProjModel(torch.nn.Module): + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + return x + +class IpAdapterModule(torch.nn.Module): + def __init__(self, num_attention_heads, attention_head_dim, input_dim): + super().__init__() + self.num_heads = num_attention_heads + self.head_dim = attention_head_dim + output_dim = num_attention_heads * attention_head_dim + self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False) + self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False) + + + def forward(self, hidden_states): + batch_size = hidden_states.shape[0] + # ip_k + ip_k = self.to_k_ip(hidden_states) + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_k = self.norm_added_k(ip_k) + # ip_v + ip_v = self.to_v_ip(hidden_states) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + return ip_k, ip_v + + +class FluxIpAdapter(torch.nn.Module): + def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57): + super().__init__() + self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)]) + self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens) + self.set_adapter() + + def set_adapter(self): + self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))} + + def forward(self, hidden_states, scale=1.0): + hidden_states = self.image_proj(hidden_states) + hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1]) + ip_kv_dict = {} + for block_id in self.call_block_id: + ipadapter_id = self.call_block_id[block_id] + ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states) + ip_kv_dict[block_id] = { + "ip_k": ip_k, + "ip_v": ip_v, + "scale": scale + } + return ip_kv_dict + + @staticmethod + def state_dict_converter(): + return FluxIpAdapterStateDictConverter() + + +class FluxIpAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {} + for name in state_dict["ip_adapter"]: + name_ = 'ipadapter_modules.' + name + state_dict_[name_] = state_dict["ip_adapter"][name] + for name in state_dict["image_proj"]: + name_ = "image_proj." + name + state_dict_[name_] = state_dict["image_proj"][name] + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict)