support sdxl controlnet union

This commit is contained in:
Artiprocher
2024-08-01 10:01:39 +08:00
parent 60d7bb52d6
commit 6f79fd6d77
10 changed files with 408 additions and 17 deletions

View File

@@ -16,6 +16,7 @@ from ..models.sd3_vae_decoder import SD3VAEDecoder
from ..models.sd3_vae_encoder import SD3VAEEncoder
from ..models.sd_controlnet import SDControlNet
from ..models.sdxl_controlnet import SDXLControlNetUnion
from ..models.sd_motion import SDMotionModel
from ..models.sdxl_motion import SDXLMotionModel
@@ -60,6 +61,7 @@ model_loader_configs = [
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.

View File

@@ -37,13 +37,14 @@ class MultiControlNetManager:
def __call__(
self,
sample, timestep, encoder_hidden_states, conditionings,
tiled=False, tile_size=64, tile_stride=32
tiled=False, tile_size=64, tile_stride=32, **kwargs
):
res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
res_stack_ = model(
sample, timestep, encoder_hidden_states, conditioning,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
processor_id=processor.processor_id
)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:

View File

@@ -47,5 +47,6 @@ class Annotator:
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
image.save("/mnt/zhongjie/DiffSynth-Studio-kolors/DiffSynth-Studio/input.jpg")
return image

View File

@@ -23,6 +23,7 @@ from .sd3_vae_decoder import SD3VAEDecoder
from .sd3_vae_encoder import SD3VAEEncoder
from .sd_controlnet import SDControlNet
from .sdxl_controlnet import SDXLControlNetUnion
from .sd_motion import SDMotionModel
from .sdxl_motion import SDXLMotionModel

View File

@@ -97,6 +97,7 @@ class SDControlNet(torch.nn.Module):
self,
sample, timestep, encoder_hidden_states, conditioning,
tiled=False, tile_size=64, tile_stride=32,
**kwargs
):
# 1. time
time_emb = self.time_proj(timestep).to(sample.dtype)

View File

@@ -0,0 +1,318 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .sdxl_unet import SDXLUNet
from .tiler import TileWorker
from .sd_controlnet import ControlNetConditioningLayer
from collections import OrderedDict
class QuickGELU(torch.nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(torch.nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = torch.nn.MultiheadAttention(d_model, n_head)
self.ln_1 = torch.nn.LayerNorm(d_model)
self.mlp = torch.nn.Sequential(OrderedDict([
("c_fc", torch.nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", torch.nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = torch.nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class SDXLControlNetUnion(torch.nn.Module):
def __init__(self, global_pool=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
torch.nn.Linear(320, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.control_type_proj = Timesteps(256)
self.control_type_embedding = torch.nn.Sequential(
torch.nn.Linear(256 * 8, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
self.controlnet_transformer = ResidualAttentionBlock(320, 8)
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
self.spatial_ch_projs = torch.nn.Linear(320, 320)
self.blocks = torch.nn.ModuleList([
# DownBlock2D
ResnetBlock(320, 320, 1280),
PushBlock(),
ResnetBlock(320, 320, 1280),
PushBlock(),
DownSampler(320),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(320, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
ResnetBlock(640, 640, 1280),
AttentionBlock(10, 64, 640, 2, 2048),
PushBlock(),
DownSampler(640),
PushBlock(),
# CrossAttnDownBlock2D
ResnetBlock(640, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
PushBlock(),
# UNetMidBlock2DCrossAttn
ResnetBlock(1280, 1280, 1280),
AttentionBlock(20, 64, 1280, 10, 2048),
ResnetBlock(1280, 1280, 1280),
PushBlock()
])
self.controlnet_blocks = torch.nn.ModuleList([
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
])
self.global_pool = global_pool
# 0 -- openpose
# 1 -- depth
# 2 -- hed/pidi/scribble/ted
# 3 -- canny/lineart/anime_lineart/mlsd
# 4 -- normal
# 5 -- segment
# 6 -- tile
# 7 -- repaint
self.task_id = {
"openpose": 0,
"depth": 1,
"softedge": 2,
"canny": 3,
"lineart": 3,
"lineart_anime": 3,
"tile": 6,
"inpaint": 7
}
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
controlnet_cond = self.controlnet_conv_in(conditioning)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[task_id]
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
x = self.controlnet_transformer(x)
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser = controlnet_cond + alpha
hidden_states = hidden_states + controlnet_cond_fuser
return hidden_states
def forward(
self,
sample, timestep, encoder_hidden_states,
conditioning, processor_id, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=32,
unet:SDXLUNet=None,
**kwargs
):
task_id = self.task_id[processor_id]
# 1. time
t_emb = self.time_proj(timestep).to(sample.dtype)
t_emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(add_time_id)
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(sample.dtype)
if unet is not None and unet.is_kolors:
add_embeds = unet.add_time_embedding(add_embeds)
else:
add_embeds = self.add_time_embedding(add_embeds)
control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
control_type[:, task_id] = 1
control_embeds = self.control_type_proj(control_type.flatten())
control_embeds = control_embeds.reshape((sample.shape[0], -1))
control_embeds = control_embeds.to(sample.dtype)
control_embeds = self.control_type_embedding(control_embeds)
time_emb = t_emb + add_embeds + control_embeds
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
text_emb = encoder_hidden_states
if unet is not None and unet.is_kolors:
text_emb = unet.text_intermediate_proj(text_emb)
res_stack = [hidden_states]
# 3. blocks
for i, block in enumerate(self.blocks):
if tiled and not isinstance(block, PushBlock):
_, _, inter_height, _ = hidden_states.shape
resize_scale = inter_height / height
hidden_states = TileWorker().tiled_forward(
lambda x: block(x, time_emb, text_emb, res_stack)[0],
hidden_states,
int(tile_size * resize_scale),
int(tile_stride * resize_scale),
tile_device=hidden_states.device,
tile_dtype=hidden_states.dtype
)
else:
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
# 4. ControlNet blocks
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
# pool
if self.global_pool:
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
return controlnet_res_stack
@staticmethod
def state_dict_converter():
return SDXLControlNetUnionStateDictConverter()
class SDXLControlNetUnionStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
# architecture
block_types = [
"ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
"ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
]
# controlnet_rename_dict
controlnet_rename_dict = {
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
"control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
"control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
"control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
"control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
}
# Rename each parameter
name_list = sorted([name for name in state_dict])
rename_dict = {}
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
for name in name_list:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
pass
elif name in controlnet_rename_dict:
names = controlnet_rename_dict[name].split(".")
elif names[0] == "controlnet_down_blocks":
names[0] = "controlnet_blocks"
elif names[0] == "controlnet_mid_block":
names = ["controlnet_blocks", "9", names[-1]]
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
elif names[0] == "control_add_embedding":
names[0] = "control_type_embedding"
elif names[0] == "transformer_layes":
names[0] = "controlnet_transformer"
names.pop(1)
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
if names[0] == "mid_block":
names.insert(1, "0")
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
block_type_with_id = ".".join(names[:4])
if block_type_with_id != last_block_type_with_id[block_type]:
block_id[block_type] += 1
last_block_type_with_id[block_type] = block_type_with_id
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
block_id[block_type] += 1
block_type_with_id = ".".join(names[:4])
names = ["blocks", str(block_id[block_type])] + names[4:]
if "ff" in names:
ff_index = names.index("ff")
component = ".".join(names[ff_index:ff_index+3])
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
names = names[:ff_index] + [component] + names[ff_index+3:]
if "to_out" in names:
names.pop(names.index("to_out") + 1)
else:
print(name, state_dict[name].shape)
# raise ValueError(f"Unknown parameters: {name}")
rename_dict[name] = ".".join(names)
# Convert state_dict
state_dict_ = {}
for name, param in state_dict.items():
if name not in rename_dict:
continue
if ".proj_in." in name or ".proj_out." in name:
param = param.squeeze()
state_dict_[rename_dict[name]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -83,6 +83,8 @@ class SDXLUNet(torch.nn.Module):
self.conv_act = torch.nn.SiLU()
self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=3, padding=1)
self.is_kolors = is_kolors
def forward(
self,
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,

View File

@@ -136,6 +136,34 @@ def lets_dance_xl(
device = "cuda",
vram_limit_level = 0,
):
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
add_time_id=add_time_id,
add_text_embeds=add_text_embeds,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
@@ -156,11 +184,31 @@ def lets_dance_xl(
# 4. blocks
for block_id, block in enumerate(unet.blocks):
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
)
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
@@ -169,6 +217,10 @@ def lets_dance_xl(
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)

View File

@@ -29,8 +29,7 @@ class SD3ImagePipeline(BasePipeline):
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
if "sd3_text_encoder_3" in model_manager.model:
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
self.dit = model_manager.fetch_model("sd3_dit")
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")

View File

@@ -25,7 +25,7 @@ class SDXLImagePipeline(BasePipeline):
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
# self.controlnet: MultiControlNetManager = None (TODO)
self.controlnet: MultiControlNetManager = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
@@ -43,7 +43,16 @@ class SDXLImagePipeline(BasePipeline):
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
# ControlNets (TODO)
# ControlNets
controlnet_units = []
for config in controlnet_config_units:
controlnet_unit = ControlNetUnit(
Annotator(config.processor_id, device=self.device),
model_manager.fetch_model("sdxl_controlnet", config.model_path),
config.scale
)
controlnet_units.append(controlnet_unit)
self.controlnet = MultiControlNetManager(controlnet_units)
# IP-Adapters
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
@@ -150,8 +159,13 @@ class SDXLImagePipeline(BasePipeline):
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Prepare ControlNets (TODO)
controlnet_kwargs = {"controlnet_frames": None}
# Prepare ControlNets
if controlnet_image is not None:
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
controlnet_image = controlnet_image.unsqueeze(1)
controlnet_kwargs = {"controlnet_frames": controlnet_image}
else:
controlnet_kwargs = {"controlnet_frames": None}
# Prepare extra input
extra_input = self.prepare_extra_input(latents)
@@ -162,14 +176,14 @@ class SDXLImagePipeline(BasePipeline):
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet, motion_modules=None, controlnet=None,
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
device=self.device,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet, motion_modules=None, controlnet=None,
self.unet, motion_modules=None, controlnet=self.controlnet,
sample=latents, timestep=timestep, **extra_input,
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
device=self.device,