From f189f9f1bea9fe185ef43601fa8675712a639ef4 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 2 Aug 2024 10:31:25 +0800 Subject: [PATCH] update UI --- diffsynth/pipelines/base.py | 19 +++++++++++++++++ diffsynth/pipelines/hunyuan_image.py | 10 ++++++--- diffsynth/pipelines/sd3_image.py | 7 +++++- diffsynth/pipelines/sd_image.py | 7 +++++- diffsynth/pipelines/sdxl_image.py | 8 ++++++- pages/1_Image_Creator.py | 32 ++++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index cb83527..8b99c82 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -31,4 +31,23 @@ class BasePipeline(torch.nn.Module): video = vae_output.cpu().permute(1, 2, 0).numpy() video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video] return video + + + def merge_latents(self, value, latents, masks, scales): + height, width = value.shape[-2:] + weight = torch.ones_like(value) + for latent, mask, scale in zip(latents, masks, scales): + mask = self.preprocess_image(mask.resize((height, width))).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, latent.shape[1], 1, 1) + value[mask] += latent[mask] * scale + weight[mask] += scale + value /= weight + return value + + + def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback): + noise_pred_global = inference_callback(prompt_emb_global) + noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] + noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) + return noise_pred \ No newline at end of file diff --git a/diffsynth/pipelines/hunyuan_image.py b/diffsynth/pipelines/hunyuan_image.py index 241f772..9181431 100644 --- a/diffsynth/pipelines/hunyuan_image.py +++ b/diffsynth/pipelines/hunyuan_image.py @@ -209,6 +209,9 @@ class HunyuanDiTImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -241,6 +244,7 @@ class HunyuanDiTImagePipeline(BasePipeline): prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) if cfg_scale != 1.0: prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts] # Prepare positional id extra_input = self.prepare_extra_input(latents, tiled, tile_size) @@ -250,9 +254,9 @@ class HunyuanDiTImagePipeline(BasePipeline): timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device) # Positive side - noise_pred_posi = self.dit( - latents, timestep=timestep, **prompt_emb_posi, **extra_input, - ) + inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + if cfg_scale != 1.0: # Negative side noise_pred_nega = self.dit( diff --git a/diffsynth/pipelines/sd3_image.py b/diffsynth/pipelines/sd3_image.py index f52c2ed..d7dd371 100644 --- a/diffsynth/pipelines/sd3_image.py +++ b/diffsynth/pipelines/sd3_image.py @@ -73,6 +73,9 @@ class SD3ImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, input_image=None, @@ -104,15 +107,17 @@ class SD3ImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts] # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = self.dit( + inference_callback = lambda prompt_emb_posi: self.dit( latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_nega = self.dit( latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, ) diff --git a/diffsynth/pipelines/sd_image.py b/diffsynth/pipelines/sd_image.py index 0b0d238..016720d 100644 --- a/diffsynth/pipelines/sd_image.py +++ b/diffsynth/pipelines/sd_image.py @@ -90,6 +90,9 @@ class SDImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -125,6 +128,7 @@ class SDImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts] # IP-Adapter if ipadapter_images is not None: @@ -147,12 +151,13 @@ class SDImagePipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = lets_dance( + inference_callback = lambda prompt_emb_posi: lets_dance( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, device=self.device, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) noise_pred_nega = lets_dance( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega, diff --git a/diffsynth/pipelines/sdxl_image.py b/diffsynth/pipelines/sdxl_image.py index c214ebd..2cd73d8 100644 --- a/diffsynth/pipelines/sdxl_image.py +++ b/diffsynth/pipelines/sdxl_image.py @@ -109,6 +109,9 @@ class SDXLImagePipeline(BasePipeline): def __call__( self, prompt, + local_prompts=[], + masks=[], + mask_scales=[], negative_prompt="", cfg_scale=7.5, clip_skip=1, @@ -146,6 +149,7 @@ class SDXLImagePipeline(BasePipeline): # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False) + prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts] # IP-Adapter if ipadapter_images is not None: @@ -175,12 +179,14 @@ class SDXLImagePipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(self.device) # Classifier-free guidance - noise_pred_posi = lets_dance_xl( + inference_callback = lambda prompt_emb_posi: lets_dance_xl( self.unet, motion_modules=None, controlnet=self.controlnet, sample=latents, timestep=timestep, **extra_input, **prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi, device=self.device, ) + noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback) + if cfg_scale != 1.0: noise_pred_nega = lets_dance_xl( self.unet, motion_modules=None, controlnet=self.controlnet, diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 9fb49ca..2d13782 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -255,6 +255,37 @@ with column_input: key="canvas" ) + num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0) + local_prompts, masks, mask_scales = [], [], [] + white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255) + for painter_tab_id in range(num_painter_layer): + with st.expander(f"Painter layer {painter_tab_id}", expanded=True): + enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True) + local_prompt = st.text_area(f"Prompt {painter_tab_id}") + mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0) + stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100) + canvas_result_local = st_canvas( + fill_color="#000000", + stroke_width=stroke_width, + stroke_color="#000000", + background_color="rgba(255, 255, 255, 0)", + background_image=white_board, + update_streamlit=True, + height=512, + width=512, + drawing_mode="freedraw", + key=f"canvas_{painter_tab_id}" + ) + if enable_local_prompt: + local_prompts.append(local_prompt) + if canvas_result_local.image_data is not None: + mask = apply_stroke_to_image(canvas_result_local.image_data, white_board) + else: + mask = white_board + mask = Image.fromarray(255 - np.array(mask)) + masks.append(mask) + mask_scales.append(mask_scale) + with column_output: run_button = st.button("Generate image", type="primary") @@ -282,6 +313,7 @@ with column_output: progress_bar_st = st.progress(0.0) image = pipeline( prompt, negative_prompt=negative_prompt, + local_prompts=local_prompts, masks=masks, mask_scales=mask_scales, cfg_scale=cfg_scale, num_inference_steps=num_inference_steps, height=height, width=width, input_image=input_image, denoising_strength=denoising_strength,