support sdxl controlnet union

This commit is contained in:
Artiprocher
2024-08-01 10:01:39 +08:00
parent 60d7bb52d6
commit 6f79fd6d77
10 changed files with 408 additions and 17 deletions

View File

@@ -136,6 +136,34 @@ def lets_dance_xl(
device = "cuda",
vram_limit_level = 0,
):
# 1. ControlNet
controlnet_insert_block_id = 22
if controlnet is not None and controlnet_frames is not None:
res_stacks = []
# process controlnet frames with batch
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
res_stack = controlnet(
sample[batch_id: batch_id_],
timestep,
encoder_hidden_states[batch_id: batch_id_],
controlnet_frames[:, batch_id: batch_id_],
add_time_id=add_time_id,
add_text_embeds=add_text_embeds,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
)
if vram_limit_level >= 1:
res_stack = [res.cpu() for res in res_stack]
res_stacks.append(res_stack)
# concat the residual
additional_res_stack = []
for i in range(len(res_stacks[0])):
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
additional_res_stack.append(res)
else:
additional_res_stack = None
# 2. time
t_emb = unet.time_proj(timestep).to(sample.dtype)
t_emb = unet.time_embedding(t_emb)
@@ -156,11 +184,31 @@ def lets_dance_xl(
# 4. blocks
for block_id, block in enumerate(unet.blocks):
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
)
# 4.1 UNet
if isinstance(block, PushBlock):
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].cpu()
elif isinstance(block, PopBlock):
if vram_limit_level>=1:
res_stack[-1] = res_stack[-1].to(device)
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
else:
hidden_states_input = hidden_states
hidden_states_output = []
for batch_id in range(0, sample.shape[0], unet_batch_size):
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
hidden_states, _, _, _ = block(
hidden_states_input[batch_id: batch_id_],
time_emb,
text_emb[batch_id: batch_id_],
res_stack,
cross_frame_attention=cross_frame_attention,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
)
hidden_states_output.append(hidden_states)
hidden_states = torch.concat(hidden_states_output, dim=0)
# 4.2 AnimateDiff
if motion_modules is not None:
if block_id in motion_modules.call_block_id:
@@ -169,6 +217,10 @@ def lets_dance_xl(
hidden_states, time_emb, text_emb, res_stack,
batch_size=1
)
# 4.3 ControlNet
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
hidden_states += additional_res_stack.pop().to(device)
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
# 5. output
hidden_states = unet.conv_norm_out(hidden_states)