mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -36,6 +36,7 @@ from ..models.flux_dit import FluxDiT
|
||||
from ..models.flux_text_encoder import FluxTextEncoder2
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from ..models.cog_dit import CogDiT
|
||||
@@ -88,6 +89,7 @@ model_loader_configs = [
|
||||
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
||||
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
@@ -102,6 +104,7 @@ huggingface_model_loader_configs = [
|
||||
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel")
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -257,6 +260,17 @@ preset_models_on_huggingface = {
|
||||
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
# RIFE
|
||||
"RIFE": [
|
||||
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
||||
@@ -541,6 +555,17 @@ preset_models_on_modelscope = {
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
||||
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
||||
],
|
||||
"InstantX/FLUX.1-dev-IP-Adapter": {
|
||||
"file_list": [
|
||||
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
||||
],
|
||||
"load_path": [
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
],
|
||||
},
|
||||
# ESRGAN
|
||||
"ESRGAN_x4": [
|
||||
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
||||
@@ -642,6 +667,7 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
||||
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
"InstantX/FLUX.1-dev-IP-Adapter",
|
||||
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
||||
"QwenPrompt",
|
||||
"OmostPrompt",
|
||||
|
||||
@@ -4,6 +4,12 @@ from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RoPEEmbedding(torch.nn.Module):
|
||||
@@ -64,8 +70,7 @@ class FluxJointAttention(torch.nn.Module):
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
# Part A
|
||||
@@ -90,6 +95,8 @@ class FluxJointAttention(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
@@ -122,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb)
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
@@ -219,7 +226,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
def process_attention(self, hidden_states, image_rotary_emb):
|
||||
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
@@ -231,16 +238,18 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
if ipadapter_kwargs_list is not None:
|
||||
hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
|
||||
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
|
||||
|
||||
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
|
||||
94
diffsynth/models/flux_ipadapter.py
Normal file
94
diffsynth/models/flux_ipadapter.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .sd3_dit import RMSNorm
|
||||
from transformers import CLIPImageProcessor
|
||||
import torch
|
||||
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
||||
)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, id_embeds):
|
||||
x = self.proj(id_embeds)
|
||||
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class IpAdapterModule(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads, attention_head_dim, input_dim):
|
||||
super().__init__()
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = attention_head_dim
|
||||
output_dim = num_attention_heads * attention_head_dim
|
||||
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
|
||||
|
||||
|
||||
def forward(self, hidden_states):
|
||||
batch_size = hidden_states.shape[0]
|
||||
# ip_k
|
||||
ip_k = self.to_k_ip(hidden_states)
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_k = self.norm_added_k(ip_k)
|
||||
# ip_v
|
||||
ip_v = self.to_v_ip(hidden_states)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
return ip_k, ip_v
|
||||
|
||||
|
||||
class FluxIpAdapter(torch.nn.Module):
|
||||
def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
|
||||
super().__init__()
|
||||
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
|
||||
self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
|
||||
self.set_adapter()
|
||||
|
||||
def set_adapter(self):
|
||||
self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||
ip_kv_dict = {}
|
||||
for block_id in self.call_block_id:
|
||||
ipadapter_id = self.call_block_id[block_id]
|
||||
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||
ip_kv_dict[block_id] = {
|
||||
"ip_k": ip_k,
|
||||
"ip_v": ip_v,
|
||||
"scale": scale
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
class FluxIpAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict["ip_adapter"]:
|
||||
name_ = 'ipadapter_modules.' + name
|
||||
state_dict_[name_] = state_dict["ip_adapter"][name]
|
||||
for name in state_dict["image_proj"]:
|
||||
name_ = "image_proj." + name
|
||||
state_dict_[name_] = state_dict["image_proj"][name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -39,6 +39,7 @@ from .hunyuan_dit import HunyuanDiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from .flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from .cog_dit import CogDiT
|
||||
|
||||
@@ -6,16 +6,21 @@ from .tiler import TileWorker
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * self.weight
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder
|
||||
from ..models import ModelManager, FluxDiT, SD3TextEncoder1, FluxTextEncoder2, FluxVAEDecoder, FluxVAEEncoder, FluxIpAdapter
|
||||
from ..controlnets import FluxMultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompters import FluxPrompter
|
||||
from ..schedulers import FlowMatchScheduler
|
||||
@@ -9,7 +9,7 @@ from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ..models.tiler import FastTileWorker
|
||||
|
||||
from transformers import SiglipVisionModel
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
@@ -25,7 +25,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.controlnet: FluxMultiControlNetManager = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet']
|
||||
self.ipadapter: FluxIpAdapter = None
|
||||
self.ipadapter_image_encoder: SiglipVisionModel = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
@@ -53,6 +55,9 @@ class FluxImagePipeline(BasePipeline):
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = FluxMultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
|
||||
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
@@ -129,18 +134,24 @@ class FluxImagePipeline(BasePipeline):
|
||||
controlnet_frames.append(image)
|
||||
return controlnet_frames
|
||||
|
||||
def prepare_ipadapter_inputs(self, images, height=384, width=384):
|
||||
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
|
||||
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
|
||||
return torch.cat(images, dim=0)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=None,
|
||||
masks=None,
|
||||
masks=None,
|
||||
mask_scales=None,
|
||||
negative_prompt="",
|
||||
cfg_scale=1.0,
|
||||
embedded_guidance=3.5,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
controlnet_image=None,
|
||||
controlnet_inpaint_mask=None,
|
||||
enable_controlnet_on_negative=False,
|
||||
@@ -157,7 +168,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
progress_bar_st=None,
|
||||
):
|
||||
height, width = self.check_resize_height_width(height, width)
|
||||
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
@@ -187,6 +198,17 @@ class FluxImagePipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
self.load_models_to_device(['ipadapter_image_encoder'])
|
||||
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
|
||||
self.load_models_to_device(['ipadapter'])
|
||||
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
|
||||
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
@@ -208,7 +230,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs
|
||||
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(
|
||||
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
|
||||
@@ -219,7 +241,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
noise_pred_nega = lets_dance_flux(
|
||||
dit=self.dit, controlnet=self.controlnet,
|
||||
hidden_states=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs,
|
||||
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -256,6 +278,7 @@ def lets_dance_flux(
|
||||
tiled=False,
|
||||
tile_size=128,
|
||||
tile_stride=64,
|
||||
ipadapter_kwargs_list={},
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
@@ -319,15 +342,27 @@ def lets_dance_flux(
|
||||
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
|
||||
block_id + num_joint_blocks, None))
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
|
||||
38
examples/Ip-Adapter/flux_ipadapter.py
Normal file
38
examples/Ip-Adapter/flux_ipadapter.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from diffsynth import ModelManager, download_models, FluxImagePipeline
|
||||
import torch
|
||||
|
||||
# Download models (automatically)
|
||||
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin`: [link](https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter/blob/main/ip-adapter.bin)
|
||||
# `models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)
|
||||
download_models(["InstantX/FLUX.1-dev-IP-Adapter"])
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
||||
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
|
||||
])
|
||||
seed = 42
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
torch.manual_seed(seed)
|
||||
origin_prompt = "a rabbit in a garden, colorful flowers"
|
||||
image = pipe(
|
||||
prompt=origin_prompt,
|
||||
cfg_scale=1.0, embedded_guidance=3.5,
|
||||
height=1280, width=960, num_inference_steps=30
|
||||
)
|
||||
image.save("style image.jpg")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
image = pipe(
|
||||
prompt="A piggy",
|
||||
cfg_scale=1.0, embedded_guidance=3.5,
|
||||
height=1280, width=960, num_inference_steps=30,
|
||||
ipadapter_images=[image], ipadapter_scale=0.7
|
||||
)
|
||||
image.save("A piggy.jpg")
|
||||
|
||||
Reference in New Issue
Block a user