diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index f8f7178..b968bb6 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -47,8 +47,11 @@ class BasePipeline(torch.nn.Module): return value - def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs={}, special_local_kwargs_list=None): - noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback, special_kwargs=None, special_local_kwargs_list=None): + if special_kwargs is None: + noise_pred_global = inference_callback(prompt_emb_global) + else: + noise_pred_global = inference_callback(prompt_emb_global, special_kwargs) if special_local_kwargs_list is None: noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] else: