Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-24 10:18:12 +00:00)

Commit: ipadapter for sdxl
@@ -22,6 +22,8 @@ from .svd_unet import SVDUNet
 from .svd_vae_decoder import SVDVAEDecoder
 from .svd_vae_encoder import SVDVAEEncoder
 
+from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
+
 
 class ModelManager:
     def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -74,6 +76,13 @@ class ModelManager:
         param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
         return param_name in state_dict and len(state_dict) == 254
 
+    def is_ipadapter_xl(self, state_dict):
+        return "image_proj" in state_dict and "ip_adapter" in state_dict
+
+    def is_ipadapter_xl_image_encoder(self, state_dict):
+        param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
+        return param_name in state_dict
+
     def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
         component_dict = {
             "image_encoder": SVDImageEncoder,
@@ -198,6 +207,22 @@ class ModelManager:
             self.model[component] = model
             self.model_path[component] = file_path
 
+    def load_ipadapter_xl(self, state_dict, file_path=""):
+        component = "ipadapter_xl"
+        model = SDXLIpAdapter()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
+    def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
+        component = "ipadapter_xl_image_encoder"
+        model = IpAdapterCLIPImageEmbedder()
+        model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def search_for_embeddings(self, state_dict):
         embeddings = []
         for k in state_dict:
@@ -247,6 +272,10 @@ class ModelManager:
             self.load_RIFE(state_dict, file_path=file_path)
         elif self.is_translator(state_dict):
            self.load_translator(state_dict, file_path=file_path)
+        elif self.is_ipadapter_xl(state_dict):
+            self.load_ipadapter_xl(state_dict, file_path=file_path)
+        elif self.is_ipadapter_xl_image_encoder(state_dict):
+            self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
 
     def load_models(self, file_path_list, lora_alphas=[]):
         for file_path in file_path_list:
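Note on the dispatch above: load_model_from_single_file routes a checkpoint by structure alone, so the two new is_ipadapter_xl* checks only inspect keys. A minimal sketch of what they expect, using the model paths from the example script added in this commit (assumes those files have been downloaded):

import torch
from safetensors.torch import load_file

# The SDXL IP-Adapter .bin holds exactly two nested dicts at the top level.
sd = torch.load("models/IpAdapter/ip-adapter_sdxl.bin", map_location="cpu")
print(sorted(sd.keys()))  # ['image_proj', 'ip_adapter'] -> is_ipadapter_xl() matches

# The companion image encoder has 48 transformer layers, so a layer-47
# parameter name is enough to tell it apart from the 32-layer SVD encoder.
enc = load_file("models/IpAdapter/image_encoder/model.safetensors")
print("vision_model.encoder.layers.47.self_attn.v_proj.weight" in enc)  # True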
@@ -299,7 +328,9 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
 def load_state_dict_from_bin(file_path, torch_dtype=None):
     state_dict = torch.load(file_path, map_location="cpu")
     if torch_dtype is not None:
-        state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
+        for i in state_dict:
+            if isinstance(state_dict[i], torch.Tensor):
+                state_dict[i] = state_dict[i].to(torch_dtype)
     return state_dict
 
 
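Why the loop replaced the dict comprehension: an IP-Adapter .bin nests its tensors under the top-level "image_proj" and "ip_adapter" dicts, and a dict has no .to() method. A self-contained illustration with stand-in tensors:

import torch

state_dict = {
    "image_proj": {"proj.weight": torch.randn(2, 2)},       # nested dict, not a tensor
    "ip_adapter": {"1.to_k_ip.weight": torch.randn(2, 2)},
}
# Old behavior: {i: state_dict[i].to(torch.float16) for i in state_dict}
# raises AttributeError on the nested dicts. The guarded loop skips them:
for i in state_dict:
    if isinstance(state_dict[i], torch.Tensor):
        state_dict[i] = state_dict[i].to(torch.float16)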
@@ -26,7 +26,15 @@ class Attention(torch.nn.Module):
         self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
         self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
 
-    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
+    def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
+        batch_size = q.shape[0]
+        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
+        hidden_states = hidden_states + scale * ip_hidden_states
+        return hidden_states
+
+    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
@@ -41,6 +49,8 @@ class Attention(torch.nn.Module):
         v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
 
         hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+        if ipadapter_kwargs is not None:
+            hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(q.dtype)
 
@@ -72,5 +82,5 @@ class Attention(torch.nn.Module):
 
         return hidden_states
 
-    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
-        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)
+    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
+        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs)
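interact_with_ipadapter is the decoupled cross-attention of IP-Adapter: the projected image tokens get their own key/value heads, and their attention output is simply added to the text cross-attention output. The pattern in isolation, with assumed toy shapes (the 4 image context tokens come from IpAdapterImageProjModel in the new file below):

import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 1, 10, 64
q    = torch.randn(batch_size, num_heads, 4096, head_dim)  # queries from latent tokens
k    = torch.randn(batch_size, num_heads, 77, head_dim)    # keys/values from text tokens
v    = torch.randn(batch_size, num_heads, 77, head_dim)
ip_k = torch.randn(batch_size, num_heads, 4, head_dim)     # keys/values from image tokens
ip_v = torch.randn(batch_size, num_heads, 4, head_dim)

hidden_states = F.scaled_dot_product_attention(q, k, v)
ip_hidden_states = F.scaled_dot_product_attention(q, ip_k, ip_v)
out = hidden_states + 1.0 * ip_hidden_states  # scale=1.0, as in interact_with_ipadapter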
@@ -47,15 +47,15 @@ class BasicTransformerBlock(torch.nn.Module):
         self.ff = torch.nn.Linear(dim * 4, dim)
 
 
-    def forward(self, hidden_states, encoder_hidden_states):
+    def forward(self, hidden_states, encoder_hidden_states, ipadapter_kwargs=None):
         # 1. Self-Attention
         norm_hidden_states = self.norm1(hidden_states)
-        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None,)
+        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
         hidden_states = attn_output + hidden_states
 
         # 2. Cross-Attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ipadapter_kwargs=ipadapter_kwargs)
         hidden_states = attn_output + hidden_states
 
         # 3. Feed-forward
@@ -150,6 +150,7 @@ class AttentionBlock(torch.nn.Module):
         hidden_states, time_emb, text_emb, res_stack,
         cross_frame_attention=False,
         tiled=False, tile_size=64, tile_stride=32,
+        ipadapter_kwargs_list={},
         **kwargs
     ):
         batch, _, height, width = hidden_states.shape
@@ -188,10 +189,11 @@ class AttentionBlock(torch.nn.Module):
             )
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         else:
-            for block in self.transformer_blocks:
+            for block_id, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states,
-                    encoder_hidden_states=encoder_hidden_states
+                    encoder_hidden_states=encoder_hidden_states,
+                    ipadapter_kwargs=ipadapter_kwargs_list.get(block_id, None)
                 )
         if cross_frame_attention:
             hidden_states = hidden_states.reshape(batch, height * width, inner_dim)
diffsynth/models/sdxl_ipadapter.py (new file, 121 lines)
@@ -0,0 +1,121 @@
+from .svd_image_encoder import SVDImageEncoder
+from transformers import CLIPImageProcessor
+import torch
+
+
+class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
+    def __init__(self):
+        super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
+        self.image_processor = CLIPImageProcessor()
+
+    def forward(self, image):
+        pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
+        return super().forward(pixel_values)
+
+
+class IpAdapterImageProjModel(torch.nn.Module):
+    def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
+        super().__init__()
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds):
+        clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+
+
+class IpAdapterModule(torch.nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, hidden_states):
+        ip_k = self.to_k_ip(hidden_states)
+        ip_v = self.to_v_ip(hidden_states)
+        return ip_k, ip_v
+
+
+class SDXLIpAdapter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
+        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
+        self.image_proj = IpAdapterImageProjModel()
+        self.set_full_adapter()
+
+    def set_full_adapter(self):
+        map_list = sum([
+            [(7, i) for i in range(2)],
+            [(10, i) for i in range(2)],
+            [(15, i) for i in range(10)],
+            [(18, i) for i in range(10)],
+            [(25, i) for i in range(10)],
+            [(28, i) for i in range(10)],
+            [(31, i) for i in range(10)],
+            [(35, i) for i in range(2)],
+            [(38, i) for i in range(2)],
+            [(41, i) for i in range(2)],
+            [(21, i) for i in range(10)],
+        ], [])
+        self.call_block_id = {i: j for j, i in enumerate(map_list)}
+
+    def set_less_adapter(self):
+        map_list = sum([
+            [(7, i) for i in range(2)],
+            [(10, i) for i in range(2)],
+            [(15, i) for i in range(10)],
+            [(18, i) for i in range(10)],
+            [(25, i) for i in range(10)],
+            [(28, i) for i in range(10)],
+            [(31, i) for i in range(10)],
+            [(35, i) for i in range(2)],
+            [(38, i) for i in range(2)],
+            [(41, i) for i in range(2)],
+            [(21, i) for i in range(10)],
+        ], [])
+        self.call_block_id = {i: j for j, i in enumerate(map_list) if j >= 34 and j < 44}
+
+    def forward(self, hidden_states, scale=1.0):
+        hidden_states = self.image_proj(hidden_states)
+        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
+        ip_kv_dict = {}
+        for (block_id, transformer_id) in self.call_block_id:
+            ipadapter_id = self.call_block_id[(block_id, transformer_id)]
+            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
+            if block_id not in ip_kv_dict:
+                ip_kv_dict[block_id] = {}
+            ip_kv_dict[block_id][transformer_id] = {
+                "ip_k": ip_k,
+                "ip_v": ip_v,
+                "scale": scale
+            }
+        return ip_kv_dict
+
+    def state_dict_converter(self):
+        return SDXLIpAdapterStateDictConverter()
+
+
+class SDXLIpAdapterStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        state_dict_ = {}
+        for name in state_dict["ip_adapter"]:
+            names = name.split(".")
+            layer_id = str(int(names[0]) // 2)
+            name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
+            state_dict_[name_] = state_dict["ip_adapter"][name]
+        for name in state_dict["image_proj"]:
+            name_ = "image_proj." + name
+            state_dict_[name_] = state_dict["image_proj"][name]
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
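On the layer_id arithmetic in from_diffusers: diffusers-format IP-Adapter weights appear to be keyed by attention-processor index over all attention layers, with the IP-Adapter tensors present only at the odd (cross-attention) positions; dividing by two packs them densely into ipadapter_modules. A hypothetical key traced through the conversion:

name = "1.to_k_ip.weight"                                   # assumed diffusers-style key
names = name.split(".")
layer_id = str(int(names[0]) // 2)                          # "1" -> "0", "3" -> "1", ...
name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
print(name_)                                                # ipadapter_modules.0.to_k_ip.weight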
@@ -25,11 +25,13 @@ class CLIPVisionEmbeddings(torch.nn.Module):
 
 
 class SVDImageEncoder(torch.nn.Module):
-    def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024):
+    def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
         super().__init__()
         self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
         self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
-        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=16, head_dim=80, use_quick_gelu=False) for _ in range(num_encoder_layers)])
+        self.encoders = torch.nn.ModuleList([
+            CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
+            for _ in range(num_encoder_layers)])
         self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
         self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
 
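The new num_heads/head_dim parameters let one encoder class cover both CLIP vision towers: the defaults keep SVD's behavior, while IpAdapterCLIPImageEmbedder passes the larger configuration. The widths are internally consistent (heads times head_dim equals embed_dim), which is easy to sanity-check:

svd_default = dict(embed_dim=1280, num_encoder_layers=32, encoder_intermediate_size=5120,
                   projection_dim=1024, num_heads=16, head_dim=80)
ipadapter_encoder = dict(embed_dim=1664, num_encoder_layers=48, encoder_intermediate_size=8192,
                         projection_dim=1280, num_heads=16, head_dim=104)
for cfg in (svd_default, ipadapter_encoder):
    assert cfg["num_heads"] * cfg["head_dim"] == cfg["embed_dim"]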
@@ -78,7 +80,7 @@ class SVDImageEncoderStateDictConverter:
             if name == "vision_model.embeddings.class_embedding":
                 param = state_dict[name].view(1, 1, -1)
             elif name == "vision_model.embeddings.position_embedding.weight":
-                param = state_dict[name].view(1, 257, 1280)
+                param = state_dict[name].unsqueeze(0)
             state_dict_[rename_dict[name]] = param
         elif name.startswith("vision_model.encoder.layers."):
             param = state_dict[name]
@@ -119,6 +119,7 @@ def lets_dance_xl(
    add_text_embeds = None,
    timestep = None,
    encoder_hidden_states = None,
+    ipadapter_kwargs_list = {},
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
@@ -151,7 +152,8 @@ def lets_dance_xl(
     for block_id, block in enumerate(unet.blocks):
         hidden_states, time_emb, text_emb, res_stack = block(
             hidden_states, time_emb, text_emb, res_stack,
-            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+            ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
         )
     # 4.2 AnimateDiff
     if motion_modules is not None:
@@ -1,7 +1,8 @@
-from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder
+from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterCLIPImageEmbedder
 # TODO: SDXL ControlNet
 from ..prompts import SDXLPrompter
 from ..schedulers import EnhancedDDIMScheduler
+from .dancer import lets_dance_xl
 import torch
 from tqdm import tqdm
 from PIL import Image
@@ -22,6 +23,8 @@ class SDXLImagePipeline(torch.nn.Module):
         self.unet: SDXLUNet = None
         self.vae_decoder: SDXLVAEDecoder = None
         self.vae_encoder: SDXLVAEEncoder = None
+        self.ipadapter_image_encoder: IpAdapterCLIPImageEmbedder = None
+        self.ipadapter: SDXLIpAdapter = None
         # TODO: SDXL ControlNet
 
     def fetch_main_models(self, model_manager: ModelManager):
@@ -35,6 +38,13 @@ class SDXLImagePipeline(torch.nn.Module):
     def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
         # TODO: SDXL ControlNet
         pass
 
 
+    def fetch_ipadapter(self, model_manager: ModelManager):
+        if "ipadapter_xl" in model_manager.model:
+            self.ipadapter = model_manager.ipadapter_xl
+        if "ipadapter_xl_image_encoder" in model_manager.model:
+            self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
+
+
     def fetch_prompter(self, model_manager: ModelManager):
@@ -50,6 +60,7 @@ class SDXLImagePipeline(torch.nn.Module):
         pipe.fetch_main_models(model_manager)
         pipe.fetch_prompter(model_manager)
         pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
+        pipe.fetch_ipadapter(model_manager)
         return pipe
 
 
@@ -74,6 +85,7 @@ class SDXLImagePipeline(torch.nn.Module):
         clip_skip=1,
         clip_skip_2=2,
         input_image=None,
+        ipadapter_images=None,
         controlnet_image=None,
         denoising_strength=1.0,
         height=1024,
@@ -118,30 +130,38 @@ class SDXLImagePipeline(torch.nn.Module):
 
         # Prepare positional id
         add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
 
+        # IP-Adapter
+        if ipadapter_images is not None:
+            ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
+            ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding)
+            ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
+        else:
+            ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
+
         # Denoise
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = torch.IntTensor((timestep,))[0].to(self.device)
 
             # Classifier-free guidance
+            noise_pred_posi = lets_dance_xl(
+                self.unet,
+                sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
+                add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
+                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+                ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
+            )
             if cfg_scale != 1.0:
-                noise_pred_posi = self.unet(
-                    latents, timestep, prompt_emb_posi,
-                    add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
-                )
-                noise_pred_nega = self.unet(
-                    latents, timestep, prompt_emb_nega,
+                noise_pred_nega = lets_dance_xl(
+                    self.unet,
+                    sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
                     add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
+                    ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
-                noise_pred = self.unet(
-                    latents, timestep, prompt_emb_posi,
-                    add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
-                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
-                )
+                noise_pred = noise_pred_posi
 
             latents = self.scheduler.step(noise_pred, timestep, latents)
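The loop now goes through lets_dance_xl so the per-block ipadapter_kwargs_list can be threaded down to each AttentionBlock, and the positive prediction is computed once before the branch, so the cfg_scale == 1.0 path reuses it instead of running the UNet a second time. The guidance combination itself is unchanged; in isolation, with stand-in tensors:

import torch

cfg_scale = 5.0
noise_pred_posi = torch.randn(1, 4, 128, 128)  # prediction with the positive prompt
noise_pred_nega = torch.randn(1, 4, 128, 128)  # prediction with the negative prompt
# extrapolate away from the negative prediction toward the positive one
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)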
examples/sdxl_ipadapter.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from diffsynth import ModelManager, SDXLImagePipeline
+import torch
+
+
+# Download models
+# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
+# `models/IpAdapter/image_encoder/model.safetensors`: [link](https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors)
+# `models/IpAdapter/ip-adapter_sdxl.bin`: [link](https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl.bin)
+
+# Load models
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
+    "models/IpAdapter/image_encoder/model.safetensors",
+    "models/IpAdapter/ip-adapter_sdxl.bin"
+])
+pipe = SDXLImagePipeline.from_model_manager(model_manager)
+pipe.ipadapter.set_less_adapter()
+
+torch.manual_seed(0)
+style_image = pipe(
+    prompt="Starry Night, blue sky, by van Gogh",
+    negative_prompt="dark, gray",
+    cfg_scale=5,
+    height=1024, width=1024, num_inference_steps=30,
+)
+style_image.save("style_image.jpg")
+
+image = pipe(
+    prompt="a cat",
+    negative_prompt="",
+    cfg_scale=5,
+    height=1024, width=1024, num_inference_steps=30,
+    ipadapter_images=[style_image]
+)
+image.save("transferred_image.jpg")
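The script calls set_less_adapter(), which keeps only ten of the adapter's key/value modules (those mapped to UNet block 28) for a looser, style-flavored transfer. A variant under the same pipeline, assumed to give the reference image a stronger pull, switches to the full adapter:

pipe.ipadapter.set_full_adapter()
image_strong = pipe(
    prompt="a cat",
    negative_prompt="",
    cfg_scale=5,
    height=1024, width=1024, num_inference_steps=30,
    ipadapter_images=[style_image]
)
image_strong.save("transferred_image_full_adapter.jpg")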