update webui

This commit is contained in:
Artiprocher
2026-04-14 15:50:48 +08:00
parent db0f1571b1
commit e2a04139b4
8 changed files with 61 additions and 28 deletions

View File

@@ -1,7 +1,7 @@
import importlib, inspect, pkgutil, traceback, torch, os, re
from typing import Union, List, Optional, Tuple, Iterable, Dict
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
from contextlib import contextmanager
from diffsynth.utils.data import VideoData
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
if image is not None: images.append(Image.open(image))
return images
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
if kwargs is None:
kwargs = {}
elements = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
elements.append(element)
return elements
def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
if value is None:
with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
if enable:
return ui
else:
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
else:
return draw_ui_element_safely(name, dtype, value)
def draw_video(name, value=None, disabled=False):
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
if ui is not None:
ui = VideoData(ui)
ui = [ui[i] for i in range(len(ui))]
return ui
def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled)
ui = st.checkbox(name, value=value, disabled=disabled)
elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled)
elif dtype == list[ModelConfig]:
ui = draw_single_model_config(name, value=value, disabled=disabled)
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"]
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str:
if "prompt" in name:
ui = st.text_area(name, value, height=3, disabled=disabled)
ui = st.text_area(name, value=value, height=3, disabled=disabled)
else:
ui = st.text_input(name, value, disabled=disabled)
ui = st.text_input(name, value=value, disabled=disabled)
elif dtype == float:
ui = st.number_input(name, value, disabled=disabled)
ui = st.number_input(name, value=value, disabled=disabled)
elif dtype == int:
ui = st.number_input(name, value, step=1, disabled=disabled)
ui = st.number_input(name, value=value, step=1, disabled=disabled)
elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui)
elif dtype == List[Image.Image]:
ui = draw_multi_images(name, value, disabled=disabled)
elif dtype == List[ControlNetInput]:
ui = draw_controlnet_inputs(name, value, disabled=disabled)
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
if "video" in name:
ui = draw_video(name, value=value, disabled=disabled)
else:
ui = draw_multi_images(name, value=value, disabled=disabled)
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
elif dtype in [List[str], list[str]]:
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
elif dtype in [List[int], list[int]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
elif dtype in [tuple[int, int], Tuple[int, int]]:
with st.container(border=True):
st.markdown(name)
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
elif isinstance(dtype, typing._LiteralGenericAlias):
with st.container(border=True):
st.markdown(f"{name} ({dtype})")
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
elif dtype is None:
if name == "progress_bar_cmd":
ui = value
@@ -260,7 +298,7 @@ def launch_webui():
with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"]
input_params = {}
params = parse_params(pipe.__call__)
params = parse_params(pipeline_class.__call__)
for param in params:
if param["name"] in ["self"]:
continue