support flux ipadapter

This commit is contained in:
root
2024-11-26 18:08:50 +08:00
parent 5fc9e53eec
commit 4f40683fd8
6 changed files with 133 additions and 19 deletions

View File

@@ -36,6 +36,7 @@ from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder2 from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_controlnet import FluxControlNet from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT from ..models.cog_dit import CogDiT
@@ -88,6 +89,7 @@ model_loader_configs = [
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"), (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"), (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"), (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
@@ -102,6 +104,7 @@ huggingface_model_loader_configs = [
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None), ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel")
] ]
patch_model_loader_configs = [ patch_model_loader_configs = [
# These configs are provided for detecting model type automatically. # These configs are provided for detecting model type automatically.
@@ -257,6 +260,17 @@ preset_models_on_huggingface = {
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
], ],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# RIFE # RIFE
"RIFE": [ "RIFE": [
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"), ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
@@ -541,6 +555,17 @@ preset_models_on_modelscope = {
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"), ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
], ],
"InstantX/FLUX.1-dev-IP-Adapter": {
"file_list": [
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
],
"load_path": [
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
],
},
# ESRGAN # ESRGAN
"ESRGAN_x4": [ "ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
@@ -642,6 +667,7 @@ Preset_model_id: TypeAlias = Literal[
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
"InstantX/FLUX.1-dev-IP-Adapter",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"QwenPrompt", "QwenPrompt",
"OmostPrompt", "OmostPrompt",

View File

@@ -4,6 +4,12 @@ from einops import rearrange
from .tiler import TileWorker from .tiler import TileWorker
from .utils import init_weights_on_device from .utils import init_weights_on_device
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
    """Blend IP-Adapter image features into an attention block's output.

    Cross-attends the block's queries ``q`` against the adapter-projected
    image keys/values (``ip_k``, ``ip_v``) and adds the result, weighted by
    ``scale``, onto ``hidden_states``. Returns the updated hidden states.
    """
    # Assumes q/ip_k/ip_v are (batch, heads, tokens, head_dim) — implied by
    # the transpose(1, 2) + reshape below; confirm against callers.
    batch, tokens = hidden_states.shape[0], hidden_states.shape[1]
    ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
    # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
    ip_out = ip_out.transpose(1, 2).reshape(batch, tokens, -1)
    return hidden_states + scale * ip_out
class RoPEEmbedding(torch.nn.Module): class RoPEEmbedding(torch.nn.Module):
@@ -64,8 +70,7 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
batch_size = hidden_states_a.shape[0] batch_size = hidden_states_a.shape[0]
# Part A # Part A
@@ -90,6 +95,8 @@ class FluxJointAttention(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype) hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
if ipadapter_kwargs_list is not None:
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
hidden_states_a = self.a_to_out(hidden_states_a) hidden_states_a = self.a_to_out(hidden_states_a)
if self.only_out_a: if self.only_out_a:
return hidden_states_a return hidden_states_a
@@ -122,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
) )
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb): def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention # Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb) attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
# Part A # Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -219,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb): def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0] batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -231,16 +238,18 @@ class FluxSingleTransformerBlock(torch.nn.Module):
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype) hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
return hidden_states return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb): def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
residual = hidden_states_a residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states) hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb) attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)

View File

@@ -39,6 +39,7 @@ from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder2 from .flux_text_encoder import FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .flux_ipadapter import FluxIpAdapter
from .cog_vae import CogVAEEncoder, CogVAEDecoder from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT from .cog_dit import CogDiT

View File

@@ -6,16 +6,21 @@ from .tiler import TileWorker
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps): def __init__(self, dim, eps, elementwise_affine=True):
super().__init__() super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states): def forward(self, hidden_states):
input_dtype = hidden_states.dtype input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps) hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states return hidden_states

View File

@@ -1,4 +1,4 @@
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompters import FluxPrompter from ..prompters import FluxPrompter
from ..schedulers import FlowMatchScheduler from ..schedulers import FlowMatchScheduler
@@ -9,7 +9,7 @@ from tqdm import tqdm
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ..models.tiler import FastTileWorker from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
class FluxImagePipeline(BasePipeline): class FluxImagePipeline(BasePipeline):
@@ -25,7 +25,9 @@ class FluxImagePipeline(BasePipeline):
self.vae_decoder: FluxVAEDecoder = None self.vae_decoder: FluxVAEDecoder = None
self.vae_encoder: FluxVAEEncoder = None self.vae_encoder: FluxVAEEncoder = None
self.controlnet: FluxMultiControlNetManager = None self.controlnet: FluxMultiControlNetManager = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet'] self.ipadapter: FluxIpAdapter = None
self.ipadapter_image_encoder: SiglipVisionModel = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
def denoising_model(self): def denoising_model(self):
@@ -53,6 +55,9 @@ class FluxImagePipeline(BasePipeline):
controlnet_units.append(controlnet_unit) controlnet_units.append(controlnet_unit)
self.controlnet = FluxMultiControlNetManager(controlnet_units) self.controlnet = FluxMultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod @staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
@@ -129,6 +134,10 @@ class FluxImagePipeline(BasePipeline):
controlnet_frames.append(image) controlnet_frames.append(image)
return controlnet_frames return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384):
    """Turn a list of PIL images into one batched tensor for the IP-Adapter image encoder.

    Every image is forced to RGB, resized to (width, height) using PIL
    resample code 3 (bicubic), run through the pipeline's preprocessing,
    moved to the pipeline's device/dtype, and stacked along the batch axis.
    """
    prepared = []
    for image in images:
        # resample=3 is PIL's bicubic filter; 384x384 matches the SigLIP encoder input.
        rgb = image.convert("RGB").resize((width, height), resample=3)
        prepared.append(self.preprocess_image(rgb).to(device=self.device, dtype=self.torch_dtype))
    return torch.cat(prepared, dim=0)
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
@@ -141,6 +150,8 @@ class FluxImagePipeline(BasePipeline):
cfg_scale=1.0, cfg_scale=1.0,
embedded_guidance=3.5, embedded_guidance=3.5,
input_image=None, input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None, controlnet_image=None,
controlnet_inpaint_mask=None, controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False, enable_controlnet_on_negative=False,
@@ -187,6 +198,17 @@ class FluxImagePipeline(BasePipeline):
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets # Prepare ControlNets
if controlnet_image is not None: if controlnet_image is not None:
self.load_models_to_device(['vae_encoder']) self.load_models_to_device(['vae_encoder'])
@@ -208,7 +230,7 @@ class FluxImagePipeline(BasePipeline):
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
) )
noise_pred_posi = self.control_noise_via_local_prompts( noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -219,7 +241,7 @@ class FluxImagePipeline(BasePipeline):
noise_pred_nega = lets_dance_flux( noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet, dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep, hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -256,6 +278,7 @@ def lets_dance_flux(
tiled=False, tiled=False,
tile_size=128, tile_size=128,
tile_stride=64, tile_stride=64,
ipadapter_kwargs_list={},
**kwargs **kwargs
): ):
if tiled: if tiled:
@@ -319,15 +342,27 @@ def lets_dance_flux(
# Joint Blocks # Joint Blocks
for block_id, block in enumerate(dit.blocks): for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id] hidden_states = hidden_states + controlnet_res_stack[block_id]
# Single Blocks # Single Blocks
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
num_joint_blocks = len(dit.blocks)
for block_id, block in enumerate(dit.single_blocks): for block_id, block in enumerate(dit.single_blocks):
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) hidden_states, prompt_emb = block(
hidden_states,
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
# ControlNet # ControlNet
if controlnet is not None and controlnet_frames is not None: if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -0,0 +1,38 @@
# Example: FLUX.1-dev with InstantX IP-Adapter.
# Two-stage demo: first render a reference image from a text prompt, then feed
# that image back in as IP-Adapter conditioning for a second generation.
from diffsynth import ModelManager, download_models, FluxImagePipeline
import torch
# Download models (automatically)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin`: [link](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/blob/main/ip-adapter.bin)
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)
download_models(["InstantX/FLUX.1-dev-IP-Adapter"])
# Load models
# NOTE(review): the FLUX base-model paths below are assumed to already exist
# locally — only the IP-Adapter is downloaded above; confirm before running.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
    "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
])
# Same seed for both generations so only the IP-Adapter conditioning differs.
seed = 42
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(seed)
# Stage 1: generate the style reference image from text only.
origin_prompt = "a rabbit in a garden, colorful flowers"
image = pipe(
    prompt=origin_prompt,
    cfg_scale=1.0, embedded_guidance=3.5,
    height=1280, width=960, num_inference_steps=30
)
image.save("style image.jpg")
torch.manual_seed(seed)
# Stage 2: new subject prompt, styled by the stage-1 image via the IP-Adapter
# (scale 0.7 controls how strongly the image conditioning is applied).
image = pipe(
    prompt="A piggy",
    cfg_scale=1.0, embedded_guidance=3.5,
    height=1280, width=960, num_inference_steps=30,
    ipadapter_images=[image], ipadapter_scale=0.7
)
image.save("A piggy.jpg")