mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update UI
This commit is contained in:
@@ -31,4 +31,23 @@ class BasePipeline(torch.nn.Module):
|
||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
||||
return video
|
||||
|
||||
|
||||
def merge_latents(self, value, latents, masks, scales):
|
||||
height, width = value.shape[-2:]
|
||||
weight = torch.ones_like(value)
|
||||
for latent, mask, scale in zip(latents, masks, scales):
|
||||
mask = self.preprocess_image(mask.resize((height, width))).mean(dim=1, keepdim=True) > 0
|
||||
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
||||
value[mask] += latent[mask] * scale
|
||||
weight[mask] += scale
|
||||
value /= weight
|
||||
return value
|
||||
|
||||
|
||||
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback):
|
||||
noise_pred_global = inference_callback(prompt_emb_global)
|
||||
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
||||
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
||||
return noise_pred
|
||||
|
||||
@@ -209,6 +209,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
@@ -241,6 +244,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
||||
|
||||
# Prepare positional id
|
||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
||||
@@ -250,9 +254,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Positive side
|
||||
noise_pred_posi = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **extra_input,
|
||||
)
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = self.dit(
|
||||
|
||||
@@ -73,6 +73,9 @@ class SD3ImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
input_image=None,
|
||||
@@ -104,15 +107,17 @@ class SD3ImagePipeline(BasePipeline):
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = self.dit(
|
||||
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
noise_pred_nega = self.dit(
|
||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
|
||||
)
|
||||
|
||||
@@ -90,6 +90,9 @@ class SDImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
@@ -125,6 +128,7 @@ class SDImagePipeline(BasePipeline):
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
@@ -147,12 +151,13 @@ class SDImagePipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance(
|
||||
inference_callback = lambda prompt_emb_posi: lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
noise_pred_nega = lets_dance(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||
|
||||
@@ -109,6 +109,9 @@ class SDXLImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
local_prompts=[],
|
||||
masks=[],
|
||||
mask_scales=[],
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
@@ -146,6 +149,7 @@ class SDXLImagePipeline(BasePipeline):
|
||||
# Encode prompts
|
||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
@@ -175,12 +179,14 @@ class SDXLImagePipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
inference_callback = lambda prompt_emb_posi: lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
sample=latents, timestep=timestep, **extra_input,
|
||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||
device=self.device,
|
||||
)
|
||||
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||
|
||||
@@ -255,6 +255,37 @@ with column_input:
|
||||
key="canvas"
|
||||
)
|
||||
|
||||
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
||||
local_prompts, masks, mask_scales = [], [], []
|
||||
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
||||
for painter_tab_id in range(num_painter_layer):
|
||||
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
||||
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
||||
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
||||
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
||||
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
||||
canvas_result_local = st_canvas(
|
||||
fill_color="#000000",
|
||||
stroke_width=stroke_width,
|
||||
stroke_color="#000000",
|
||||
background_color="rgba(255, 255, 255, 0)",
|
||||
background_image=white_board,
|
||||
update_streamlit=True,
|
||||
height=512,
|
||||
width=512,
|
||||
drawing_mode="freedraw",
|
||||
key=f"canvas_{painter_tab_id}"
|
||||
)
|
||||
if enable_local_prompt:
|
||||
local_prompts.append(local_prompt)
|
||||
if canvas_result_local.image_data is not None:
|
||||
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
||||
else:
|
||||
mask = white_board
|
||||
mask = Image.fromarray(255 - np.array(mask))
|
||||
masks.append(mask)
|
||||
mask_scales.append(mask_scale)
|
||||
|
||||
|
||||
with column_output:
|
||||
run_button = st.button("Generate image", type="primary")
|
||||
@@ -282,6 +313,7 @@ with column_output:
|
||||
progress_bar_st = st.progress(0.0)
|
||||
image = pipeline(
|
||||
prompt, negative_prompt=negative_prompt,
|
||||
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
||||
height=height, width=width,
|
||||
input_image=input_image, denoising_strength=denoising_strength,
|
||||
|
||||
Reference in New Issue
Block a user