support flux ipadapter

This commit is contained in:
root
2024-11-26 18:08:50 +08:00
parent 5fc9e53eec
commit 4f40683fd8
6 changed files with 133 additions and 19 deletions

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler
@@ -9,7 +9,7 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
class FluxImagePipeline(BasePipeline):
@@ -25,7 +25,9 @@ class FluxImagePipeline(BasePipeline):
self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None
self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
def denoising_model(self):
@@ -53,6 +55,9 @@ class FluxImagePipeline(BasePipeline):
controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
@@ -129,18 +134,24 @@ class FluxImagePipeline(BasePipeline):
controlnet_frames.append(image)
return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384):
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
@torch.no_grad()
def __call__(
self,
prompt,
local_prompts=None,
masks=None,
masks=None,
mask_scales=None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
@@ -157,7 +168,7 @@ class FluxImagePipeline(BasePipeline):
progress_bar_st=None,
):
height, width = self.check_resize_height_width(height, width)
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
@@ -187,6 +198,17 @@ class FluxImagePipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
@@ -208,7 +230,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -219,7 +241,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -256,6 +278,7 @@ def lets_dance_flux(
tiled=False,
tile_size=128,
tile_stride=64,
ipadapter_kwargs_list={},
**kwargs
):
if tiled:
@@ -319,15 +342,27 @@ def lets_dance_flux(
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]