Merge pull request #275 from mi804/flux_ipadapter

Flux ipadapter
This commit is contained in:
Zhongjie Duan
2024-11-28 10:43:06 +08:00
committed by GitHub
7 changed files with 227 additions and 19 deletions

View File

@@ -36,6 +36,7 @@ from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
@@ -88,6 +89,7 @@ model_loader_configs = [
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
@@ -102,6 +104,7 @@ huggingface_model_loader_configs = [
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel")
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -257,6 +260,17 @@ preset_models_on_huggingface = {
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# RIFE
"RIFE": [
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
@@ -541,6 +555,17 @@ preset_models_on_modelscope = {
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -642,6 +667,7 @@ Preset_model_id: TypeAlias = Literal[
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt",
"OmostPrompt",

View File

@@ -4,6 +4,12 @@ from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
batch_size, num_tokens = hidden_states.shape[0:2]
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
class RoPEEmbedding(torch.nn.Module):
@@ -64,8 +70,7 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
@@ -90,6 +95,8 @@ class FluxJointAttention(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a:
return hidden_states_a
@@ -122,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -219,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb):
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -231,16 +238,18 @@ class FluxSingleTransformerBlock(torch.nn.Module):
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb)
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)

View File

@@ -0,0 +1,94 @@
from .svd_image_encoder import SVDImageEncoder
from .sd3_dit import RMSNorm
from transformers import CLIPImageProcessor
import torch
class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x
class IpAdapterModule(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
super().__init__()
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
output_dim = num_attention_heads * attention_head_dim
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
def forward(self, hidden_states):
batch_size = hidden_states.shape[0]
# ip_k
ip_k = self.to_k_ip(hidden_states)
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
ip_k = self.norm_added_k(ip_k)
# ip_v
ip_v = self.to_v_ip(hidden_states)
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return ip_k, ip_v
class FluxIpAdapter(torch.nn.Module):
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
super().__init__()
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
self.set_adapter()
def set_adapter(self):
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
def forward(self, hidden_states, scale=1.0):
hidden_states = self.image_proj(hidden_states)
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
ip_kv_dict = {}
for block_id in self.call_block_id:
ipadapter_id = self.call_block_id[block_id]
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
ip_kv_dict[block_id] = {
"ip_k": ip_k,
"ip_v": ip_v,
"scale": scale
}
return ip_kv_dict
@staticmethod
def state_dict_converter():
return FluxIpAdapterStateDictConverter()
class FluxIpAdapterStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name in state_dict["ip_adapter"]:
name_ = 'ipadapter_modules.' + name
state_dict_[name_] = state_dict["ip_adapter"][name]
for name in state_dict["image_proj"]:
name_ = "image_proj." + name
state_dict_[name_] = state_dict["image_proj"][name]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -39,6 +39,7 @@ from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .flux_ipadapter import FluxIpAdapter
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT

View File

@@ -6,16 +6,21 @@ from .tiler import TileWorker
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
@@ -9,7 +9,7 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
class FluxImagePipeline(BasePipeline):
@@ -25,7 +25,9 @@ class FluxImagePipeline(BasePipeline):
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
def denoising_model(self):
@@ -53,6 +55,9 @@ class FluxImagePipeline(BasePipeline):
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
@@ -129,18 +134,24 @@ class FluxImagePipeline(BasePipeline):
controlnet_frames.append(image)
return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384):
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=None,
masks=None,
masks=None,
mask_scales=None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
@@ -157,7 +168,7 @@ class FluxImagePipeline(BasePipeline):
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -187,6 +198,17 @@ class FluxImagePipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
@@ -208,7 +230,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -219,7 +241,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -256,6 +278,7 @@ def lets_dance_flux(
tiled=False,
tile_size=128,
tile_stride=64,
ipadapter_kwargs_list={},
**kwargs
):
if tiled:
@@ -319,15 +342,27 @@ def lets_dance_flux(
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -0,0 +1,38 @@
from diffsynth import ModelManager, download_models, FluxImagePipeline
import torch
# Download models (automatically)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin`: [link](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/blob/main/ip-adapter.bin)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)
download_models(["InstantX/FLUX.1-dev-IP-Adapter"])
# Load models
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
seed = 42
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(seed)
origin_prompt = "a rabbit in a garden, colorful flowers"
image = pipe(
prompt=origin_prompt,
cfg_scale=1.0, embedded_guidance=3.5,
height=1280, width=960, num_inference_steps=30
)
image.save("style image.jpg")
torch.manual_seed(seed)
image = pipe(
prompt="A piggy",
cfg_scale=1.0, embedded_guidance=3.5,
height=1280, width=960, num_inference_steps=30,
ipadapter_images=[image], ipadapter_scale=0.7
)
image.save("A piggy.jpg")