From a6d6553ceea678e58b9dbfafd3d0023cda12d392 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Oct 2024 17:36:22 +0800 Subject: [PATCH] bug fix --- diffsynth/pipelines/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index f8f7178..b968bb6 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -47,8 +47,11 @@ class BasePipeline(torch.nn.Module): return value - def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs={}, special_local_kwargs_list=None): - noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None): + if special_kwargs is None: + noise_pred_global = inference_callback(prompt_emb_global) + else: + noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) if special_local_kwargs_list is None: noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] else: