update webui

This commit is contained in:
Artiprocher
2026-04-14 15:50:48 +08:00
parent db0f1571b1
commit e2a04139b4
8 changed files with 61 additions and 28 deletions

View File

@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 4.0, cfg_scale: float = 4.0,
# Image # Image

View File

@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 1.0, cfg_scale: float = 1.0,
embedded_guidance: float = 4.0, embedded_guidance: float = 4.0,

View File

@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 1.0, cfg_scale: float = 1.0,
embedded_guidance: float = 3.5, embedded_guidance: float = 3.5,
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
sigma_shift: float = None, sigma_shift: float = None,
# Steps # Steps
num_inference_steps: int = 30, num_inference_steps: int = 30,
# local prompts
multidiffusion_prompts:tuple[str] =(),
multidiffusion_masks:tuple[str]=(),
multidiffusion_scales:tuple[str]=(),
# Kontext # Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None, kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet # ControlNet
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images, "kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs, "controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,

View File

@@ -169,7 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
denoising_strength: float = 1.0, denoising_strength: float = 1.0,
# Image-to-video # Image-to-video

View File

@@ -115,7 +115,7 @@ class MovaAudioVideoPipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
# Image-to-video # Image-to-video
input_image: Image.Image = None, input_image: Image.Image = None,

View File

@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 4.0, cfg_scale: float = 4.0,
# Image # Image

View File

@@ -190,7 +190,7 @@ class WanVideoPipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str, prompt: str = "",
negative_prompt: str = "", negative_prompt: str = "",
# Image-to-video # Image-to-video
input_image: Image.Image = None, input_image: Image.Image = None,

View File

@@ -1,7 +1,7 @@
import importlib, inspect, pkgutil, traceback, torch, os, re import importlib, inspect, pkgutil, traceback, torch, os, re, typing
from typing import Union, List, Optional, Tuple, Iterable, Dict from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
from contextlib import contextmanager from contextlib import contextmanager
from diffsynth.utils.data import VideoData
import streamlit as st import streamlit as st
from diffsynth import ModelConfig from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
if image is not None: images.append(Image.open(image)) if image is not None: images.append(Image.open(image))
return images return images
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
if kwargs is None:
kwargs = {}
elements = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
elements.append(element)
return elements
def draw_controlnet_input(name="", value=None, disabled=False): def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True): with st.container(border=True):
st.markdown(name) st.markdown(name)
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
if value is None: if value is None:
with st.container(border=True): with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False) enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable) ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
if enable: if enable:
return ui return ui
else: else:
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
else: else:
return draw_ui_element_safely(name, dtype, value) return draw_ui_element_safely(name, dtype, value)
def draw_video(name, value=None, disabled=False):
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
if ui is not None:
ui = VideoData(ui)
ui = [ui[i] for i in range(len(ui))]
return ui
def draw_ui_element_safely(name, dtype, value, disabled=False): def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype: if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
option_map = {"cuda": "cuda", "cpu": "cpu"} option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled) ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool: elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled) ui = st.checkbox(name, value=value, disabled=disabled)
elif dtype == ModelConfig: elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled) ui = draw_single_model_config(name, value=value, disabled=disabled)
elif dtype == list[ModelConfig]: elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
if name == "model_configs" and "model_configs_from_example" in st.session_state: if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"] model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"] del st.session_state["model_configs_from_example"]
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
ui = draw_multi_model_config(name, disabled=disabled) ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str: elif dtype == str:
if "prompt" in name: if "prompt" in name:
ui = st.text_area(name, value, height=3, disabled=disabled) ui = st.text_area(name, value=value, height=3, disabled=disabled)
else: else:
ui = st.text_input(name, value, disabled=disabled) ui = st.text_input(name, value=value, disabled=disabled)
elif dtype == float: elif dtype == float:
ui = st.number_input(name, value, disabled=disabled) ui = st.number_input(name, value=value, disabled=disabled)
elif dtype == int: elif dtype == int:
ui = st.number_input(name, value, step=1, disabled=disabled) ui = st.number_input(name, value=value, step=1, disabled=disabled)
elif dtype == Image.Image: elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled) ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui) if ui is not None: ui = Image.open(ui)
elif dtype == List[Image.Image]: elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
ui = draw_multi_images(name, value, disabled=disabled) if "video" in name:
elif dtype == List[ControlNetInput]: ui = draw_video(name, value=value, disabled=disabled)
ui = draw_controlnet_inputs(name, value, disabled=disabled) else:
ui = draw_multi_images(name, value=value, disabled=disabled)
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
elif dtype in [List[str], list[str]]:
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
elif dtype in [List[int], list[int]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
elif dtype in [tuple[int, int], Tuple[int, int]]:
with st.container(border=True):
st.markdown(name)
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
elif isinstance(dtype, typing._LiteralGenericAlias):
with st.container(border=True):
st.markdown(f"{name} ({dtype})")
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
elif dtype is None: elif dtype is None:
if name == "progress_bar_cmd": if name == "progress_bar_cmd":
ui = value ui = value
@@ -260,7 +298,7 @@ def launch_webui():
with st.expander("Input", expanded=True): with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"] pipe = st.session_state["pipe"]
input_params = {} input_params = {}
params = parse_params(pipe.__call__) params = parse_params(pipeline_class.__call__)
for param in params: for param in params:
if param["name"] in ["self"]: if param["name"] in ["self"]:
continue continue