mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-14 21:58:17 +00:00
update webui
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import importlib, inspect, pkgutil, traceback, torch, os, re
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
|
||||
from contextlib import contextmanager
|
||||
|
||||
from diffsynth.utils.data import VideoData
|
||||
import streamlit as st
|
||||
from diffsynth import ModelConfig
|
||||
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
||||
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
|
||||
if image is not None: images.append(Image.open(image))
|
||||
return images
|
||||
|
||||
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
elements = []
|
||||
with st.container(border=True):
|
||||
st.markdown(name)
|
||||
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
||||
for i in range(num):
|
||||
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
|
||||
elements.append(element)
|
||||
return elements
|
||||
|
||||
def draw_controlnet_input(name="", value=None, disabled=False):
|
||||
with st.container(border=True):
|
||||
st.markdown(name)
|
||||
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
|
||||
if value is None:
|
||||
with st.container(border=True):
|
||||
enable = st.checkbox(f"Enable {name}", value=False)
|
||||
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
|
||||
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
|
||||
if enable:
|
||||
return ui
|
||||
else:
|
||||
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
|
||||
else:
|
||||
return draw_ui_element_safely(name, dtype, value)
|
||||
|
||||
def draw_video(name, value=None, disabled=False):
|
||||
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
|
||||
if ui is not None:
|
||||
ui = VideoData(ui)
|
||||
ui = [ui[i] for i in range(len(ui))]
|
||||
return ui
|
||||
|
||||
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
if dtype == torch.dtype:
|
||||
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
||||
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
||||
elif dtype == bool:
|
||||
ui = st.checkbox(name, value, disabled=disabled)
|
||||
ui = st.checkbox(name, value=value, disabled=disabled)
|
||||
elif dtype == ModelConfig:
|
||||
ui = draw_single_model_config(name, value, disabled=disabled)
|
||||
elif dtype == list[ModelConfig]:
|
||||
ui = draw_single_model_config(name, value=value, disabled=disabled)
|
||||
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
|
||||
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
||||
model_configs = st.session_state["model_configs_from_example"]
|
||||
del st.session_state["model_configs_from_example"]
|
||||
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
ui = draw_multi_model_config(name, disabled=disabled)
|
||||
elif dtype == str:
|
||||
if "prompt" in name:
|
||||
ui = st.text_area(name, value, height=3, disabled=disabled)
|
||||
ui = st.text_area(name, value=value, height=3, disabled=disabled)
|
||||
else:
|
||||
ui = st.text_input(name, value, disabled=disabled)
|
||||
ui = st.text_input(name, value=value, disabled=disabled)
|
||||
elif dtype == float:
|
||||
ui = st.number_input(name, value, disabled=disabled)
|
||||
ui = st.number_input(name, value=value, disabled=disabled)
|
||||
elif dtype == int:
|
||||
ui = st.number_input(name, value, step=1, disabled=disabled)
|
||||
ui = st.number_input(name, value=value, step=1, disabled=disabled)
|
||||
elif dtype == Image.Image:
|
||||
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
||||
if ui is not None: ui = Image.open(ui)
|
||||
elif dtype == List[Image.Image]:
|
||||
ui = draw_multi_images(name, value, disabled=disabled)
|
||||
elif dtype == List[ControlNetInput]:
|
||||
ui = draw_controlnet_inputs(name, value, disabled=disabled)
|
||||
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
|
||||
if "video" in name:
|
||||
ui = draw_video(name, value=value, disabled=disabled)
|
||||
else:
|
||||
ui = draw_multi_images(name, value=value, disabled=disabled)
|
||||
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
|
||||
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
|
||||
elif dtype in [List[str], list[str]]:
|
||||
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
|
||||
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
|
||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
|
||||
elif dtype in [List[int], list[int]]:
|
||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
|
||||
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
|
||||
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
|
||||
elif dtype in [tuple[int, int], Tuple[int, int]]:
|
||||
with st.container(border=True):
|
||||
st.markdown(name)
|
||||
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
|
||||
elif isinstance(dtype, typing._LiteralGenericAlias):
|
||||
with st.container(border=True):
|
||||
st.markdown(f"{name} ({dtype})")
|
||||
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
|
||||
elif dtype is None:
|
||||
if name == "progress_bar_cmd":
|
||||
ui = value
|
||||
@@ -260,7 +298,7 @@ def launch_webui():
|
||||
with st.expander("Input", expanded=True):
|
||||
pipe = st.session_state["pipe"]
|
||||
input_params = {}
|
||||
params = parse_params(pipe.__call__)
|
||||
params = parse_params(pipeline_class.__call__)
|
||||
for param in params:
|
||||
if param["name"] in ["self"]:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user