mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support sdxl controlnet union
This commit is contained in:
@@ -136,6 +136,34 @@ def lets_dance_xl(
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 1. ControlNet
|
||||
controlnet_insert_block_id = 22
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
res_stacks = []
|
||||
# process controlnet frames with batch
|
||||
for batch_id in range(0, sample.shape[0], controlnet_batch_size):
|
||||
batch_id_ = min(batch_id + controlnet_batch_size, sample.shape[0])
|
||||
res_stack = controlnet(
|
||||
sample[batch_id: batch_id_],
|
||||
timestep,
|
||||
encoder_hidden_states[batch_id: batch_id_],
|
||||
controlnet_frames[:, batch_id: batch_id_],
|
||||
add_time_id=add_time_id,
|
||||
add_text_embeds=add_text_embeds,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
unet=unet, # for Kolors, some modules in ControlNets will be replaced.
|
||||
)
|
||||
if vram_limit_level >= 1:
|
||||
res_stack = [res.cpu() for res in res_stack]
|
||||
res_stacks.append(res_stack)
|
||||
# concat the residual
|
||||
additional_res_stack = []
|
||||
for i in range(len(res_stacks[0])):
|
||||
res = torch.concat([res_stack[i] for res_stack in res_stacks], dim=0)
|
||||
additional_res_stack.append(res)
|
||||
else:
|
||||
additional_res_stack = None
|
||||
|
||||
# 2. time
|
||||
t_emb = unet.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = unet.time_embedding(t_emb)
|
||||
@@ -156,11 +184,31 @@ def lets_dance_xl(
|
||||
|
||||
# 4. blocks
|
||||
for block_id, block in enumerate(unet.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {})
|
||||
)
|
||||
# 4.1 UNet
|
||||
if isinstance(block, PushBlock):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].cpu()
|
||||
elif isinstance(block, PopBlock):
|
||||
if vram_limit_level>=1:
|
||||
res_stack[-1] = res_stack[-1].to(device)
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
else:
|
||||
hidden_states_input = hidden_states
|
||||
hidden_states_output = []
|
||||
for batch_id in range(0, sample.shape[0], unet_batch_size):
|
||||
batch_id_ = min(batch_id + unet_batch_size, sample.shape[0])
|
||||
hidden_states, _, _, _ = block(
|
||||
hidden_states_input[batch_id: batch_id_],
|
||||
time_emb,
|
||||
text_emb[batch_id: batch_id_],
|
||||
res_stack,
|
||||
cross_frame_attention=cross_frame_attention,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, {}),
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
)
|
||||
hidden_states_output.append(hidden_states)
|
||||
hidden_states = torch.concat(hidden_states_output, dim=0)
|
||||
# 4.2 AnimateDiff
|
||||
if motion_modules is not None:
|
||||
if block_id in motion_modules.call_block_id:
|
||||
@@ -169,6 +217,10 @@ def lets_dance_xl(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
batch_size=1
|
||||
)
|
||||
# 4.3 ControlNet
|
||||
if block_id == controlnet_insert_block_id and additional_res_stack is not None:
|
||||
hidden_states += additional_res_stack.pop().to(device)
|
||||
res_stack = [res + additional_res for res, additional_res in zip(res_stack, additional_res_stack)]
|
||||
|
||||
# 5. output
|
||||
hidden_states = unet.conv_norm_out(hidden_states)
|
||||
|
||||
@@ -29,8 +29,7 @@ class SD3ImagePipeline(BasePipeline):
|
||||
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("sd3_text_encoder_2")
|
||||
if "sd3_text_encoder_3" in model_manager.model:
|
||||
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
|
||||
self.text_encoder_3 = model_manager.fetch_model("sd3_text_encoder_3")
|
||||
self.dit = model_manager.fetch_model("sd3_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("sd3_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sd3_vae_encoder")
|
||||
|
||||
@@ -25,7 +25,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
@@ -43,7 +43,16 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
|
||||
# ControlNets (TODO)
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sdxl_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
@@ -150,8 +159,13 @@ class SDXLImagePipeline(BasePipeline):
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets (TODO)
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
@@ -162,14 +176,14 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
|
||||
Reference in New Issue
Block a user