mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
support sdxl controlnet union
This commit is contained in:
@@ -25,7 +25,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
# self.controlnet: MultiControlNetManager = None (TODO)
|
||||
self.controlnet: MultiControlNetManager = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
@@ -43,7 +43,16 @@ class SDXLImagePipeline(BasePipeline):
|
||||
self.vae_decoder = model_manager.fetch_model("sdxl_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("sdxl_vae_encoder")
|
||||
|
||||
# ControlNets (TODO)
|
||||
# ControlNets
|
||||
controlnet_units = []
|
||||
for config in controlnet_config_units:
|
||||
controlnet_unit = ControlNetUnit(
|
||||
Annotator(config.processor_id, device=self.device),
|
||||
model_manager.fetch_model("sdxl_controlnet", config.model_path),
|
||||
config.scale
|
||||
)
|
||||
controlnet_units.append(controlnet_unit)
|
||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||
|
||||
# IP-Adapters
|
||||
self.ipadapter = model_manager.fetch_model("sdxl_ipadapter")
|
||||
@@ -150,8 +159,13 @@ class SDXLImagePipeline(BasePipeline):
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
|
||||
|
||||
# Prepare ControlNets (TODO)
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
# Prepare ControlNets
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = self.controlnet.process_image(controlnet_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
controlnet_image = controlnet_image.unsqueeze(1)
|
||||
controlnet_kwargs = {"controlnet_frames": controlnet_image}
|
||||
else:
|
||||
controlnet_kwargs = {"controlnet_frames": None}
|
||||
|
||||
# Prepare extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
@@ -162,14 +176,14 @@ class SDXLImagePipeline(BasePipeline):
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=None,
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
device=self.device,
|
||||
|
||||
Reference in New Issue
Block a user