mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
update UI
This commit is contained in:
@@ -31,4 +31,23 @@ class BasePipeline(torch.nn.Module):
|
|||||||
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
video = vae_output.cpu().permute(1, 2, 0).numpy()
|
||||||
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
video = [Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) for image in video]
|
||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
def merge_latents(self, value, latents, masks, scales):
|
||||||
|
height, width = value.shape[-2:]
|
||||||
|
weight = torch.ones_like(value)
|
||||||
|
for latent, mask, scale in zip(latents, masks, scales):
|
||||||
|
mask = self.preprocess_image(mask.resize((height, width))).mean(dim=1, keepdim=True) > 0
|
||||||
|
mask = mask.repeat(1, latent.shape[1], 1, 1)
|
||||||
|
value[mask] += latent[mask] * scale
|
||||||
|
weight[mask] += scale
|
||||||
|
value /= weight
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def control_noise_via_local_prompts(self, prompt_emb_global, prompt_emb_locals, masks, mask_scales, inference_callback):
|
||||||
|
noise_pred_global = inference_callback(prompt_emb_global)
|
||||||
|
noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals]
|
||||||
|
noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales)
|
||||||
|
return noise_pred
|
||||||
|
|
||||||
@@ -209,6 +209,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
|
local_prompts=[],
|
||||||
|
masks=[],
|
||||||
|
mask_scales=[],
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
clip_skip=1,
|
clip_skip=1,
|
||||||
@@ -241,6 +244,7 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||||
|
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
||||||
|
|
||||||
# Prepare positional id
|
# Prepare positional id
|
||||||
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
extra_input = self.prepare_extra_input(latents, tiled, tile_size)
|
||||||
@@ -250,9 +254,9 @@ class HunyuanDiTImagePipeline(BasePipeline):
|
|||||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
# Positive side
|
# Positive side
|
||||||
noise_pred_posi = self.dit(
|
inference_callback = lambda prompt_emb_posi: self.dit(latents, timestep=timestep, **prompt_emb_posi, **extra_input)
|
||||||
latents, timestep=timestep, **prompt_emb_posi, **extra_input,
|
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||||
)
|
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
# Negative side
|
# Negative side
|
||||||
noise_pred_nega = self.dit(
|
noise_pred_nega = self.dit(
|
||||||
|
|||||||
@@ -73,6 +73,9 @@ class SD3ImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
|
local_prompts=[],
|
||||||
|
masks=[],
|
||||||
|
mask_scales=[],
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
input_image=None,
|
input_image=None,
|
||||||
@@ -104,15 +107,17 @@ class SD3ImagePipeline(BasePipeline):
|
|||||||
# Encode prompts
|
# Encode prompts
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||||
|
prompt_emb_locals = [self.encode_prompt(prompt_local) for prompt_local in local_prompts]
|
||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
noise_pred_posi = self.dit(
|
inference_callback = lambda prompt_emb_posi: self.dit(
|
||||||
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
|
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs,
|
||||||
)
|
)
|
||||||
|
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||||
noise_pred_nega = self.dit(
|
noise_pred_nega = self.dit(
|
||||||
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
|
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -90,6 +90,9 @@ class SDImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
|
local_prompts=[],
|
||||||
|
masks=[],
|
||||||
|
mask_scales=[],
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
clip_skip=1,
|
clip_skip=1,
|
||||||
@@ -125,6 +128,7 @@ class SDImagePipeline(BasePipeline):
|
|||||||
# Encode prompts
|
# Encode prompts
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, positive=True)
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, positive=False)
|
||||||
|
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, positive=True) for prompt_local in local_prompts]
|
||||||
|
|
||||||
# IP-Adapter
|
# IP-Adapter
|
||||||
if ipadapter_images is not None:
|
if ipadapter_images is not None:
|
||||||
@@ -147,12 +151,13 @@ class SDImagePipeline(BasePipeline):
|
|||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
noise_pred_posi = lets_dance(
|
inference_callback = lambda prompt_emb_posi: lets_dance(
|
||||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||||
sample=latents, timestep=timestep,
|
sample=latents, timestep=timestep,
|
||||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||||
noise_pred_nega = lets_dance(
|
noise_pred_nega = lets_dance(
|
||||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||||
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
sample=latents, timestep=timestep, **prompt_emb_nega, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_nega,
|
||||||
|
|||||||
@@ -109,6 +109,9 @@ class SDXLImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
|
local_prompts=[],
|
||||||
|
masks=[],
|
||||||
|
mask_scales=[],
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
cfg_scale=7.5,
|
cfg_scale=7.5,
|
||||||
clip_skip=1,
|
clip_skip=1,
|
||||||
@@ -146,6 +149,7 @@ class SDXLImagePipeline(BasePipeline):
|
|||||||
# Encode prompts
|
# Encode prompts
|
||||||
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
prompt_emb_posi = self.encode_prompt(prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True)
|
||||||
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
prompt_emb_nega = self.encode_prompt(negative_prompt, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=False)
|
||||||
|
prompt_emb_locals = [self.encode_prompt(prompt_local, clip_skip=clip_skip, clip_skip_2=clip_skip_2, positive=True) for prompt_local in local_prompts]
|
||||||
|
|
||||||
# IP-Adapter
|
# IP-Adapter
|
||||||
if ipadapter_images is not None:
|
if ipadapter_images is not None:
|
||||||
@@ -175,12 +179,14 @@ class SDXLImagePipeline(BasePipeline):
|
|||||||
timestep = timestep.unsqueeze(0).to(self.device)
|
timestep = timestep.unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
noise_pred_posi = lets_dance_xl(
|
inference_callback = lambda prompt_emb_posi: lets_dance_xl(
|
||||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||||
sample=latents, timestep=timestep, **extra_input,
|
sample=latents, timestep=timestep, **extra_input,
|
||||||
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
**prompt_emb_posi, **controlnet_kwargs, **tiler_kwargs, **ipadapter_kwargs_list_posi,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
noise_pred_posi = self.control_noise_via_local_prompts(prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback)
|
||||||
|
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = lets_dance_xl(
|
noise_pred_nega = lets_dance_xl(
|
||||||
self.unet, motion_modules=None, controlnet=self.controlnet,
|
self.unet, motion_modules=None, controlnet=self.controlnet,
|
||||||
|
|||||||
@@ -255,6 +255,37 @@ with column_input:
|
|||||||
key="canvas"
|
key="canvas"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
|
||||||
|
local_prompts, masks, mask_scales = [], [], []
|
||||||
|
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
|
||||||
|
for painter_tab_id in range(num_painter_layer):
|
||||||
|
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
|
||||||
|
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
|
||||||
|
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
|
||||||
|
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
|
||||||
|
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
|
||||||
|
canvas_result_local = st_canvas(
|
||||||
|
fill_color="#000000",
|
||||||
|
stroke_width=stroke_width,
|
||||||
|
stroke_color="#000000",
|
||||||
|
background_color="rgba(255, 255, 255, 0)",
|
||||||
|
background_image=white_board,
|
||||||
|
update_streamlit=True,
|
||||||
|
height=512,
|
||||||
|
width=512,
|
||||||
|
drawing_mode="freedraw",
|
||||||
|
key=f"canvas_{painter_tab_id}"
|
||||||
|
)
|
||||||
|
if enable_local_prompt:
|
||||||
|
local_prompts.append(local_prompt)
|
||||||
|
if canvas_result_local.image_data is not None:
|
||||||
|
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
|
||||||
|
else:
|
||||||
|
mask = white_board
|
||||||
|
mask = Image.fromarray(255 - np.array(mask))
|
||||||
|
masks.append(mask)
|
||||||
|
mask_scales.append(mask_scale)
|
||||||
|
|
||||||
|
|
||||||
with column_output:
|
with column_output:
|
||||||
run_button = st.button("Generate image", type="primary")
|
run_button = st.button("Generate image", type="primary")
|
||||||
@@ -282,6 +313,7 @@ with column_output:
|
|||||||
progress_bar_st = st.progress(0.0)
|
progress_bar_st = st.progress(0.0)
|
||||||
image = pipeline(
|
image = pipeline(
|
||||||
prompt, negative_prompt=negative_prompt,
|
prompt, negative_prompt=negative_prompt,
|
||||||
|
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
|
||||||
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
|
||||||
height=height, width=width,
|
height=height, width=width,
|
||||||
input_image=input_image, denoising_strength=denoising_strength,
|
input_image=input_image, denoising_strength=denoising_strength,
|
||||||
|
|||||||
Reference in New Issue
Block a user