Compare commits

..

2 Commits

Author SHA1 Message Date
Artiprocher
e2a04139b4 update webui 2026-04-14 15:50:48 +08:00
lzws
db0f1571b1 fix pipeline args 2026-04-14 15:17:19 +08:00
8 changed files with 61 additions and 28 deletions

View File

@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 4.0,

View File

@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
sigma_shift: float = None,
# Steps
num_inference_steps: int = 30,
# local prompts
multidiffusion_prompts:tuple[str] =(),
multidiffusion_masks:tuple[str]=(),
multidiffusion_scales:tuple[str]=(),
# Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,

View File

@@ -169,7 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
denoising_strength: float = 1.0,
# Image-to-video

View File

@@ -115,7 +115,7 @@ class MovaAudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
# Image-to-video
input_image: Image.Image = None,

View File

@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -190,7 +190,7 @@ class WanVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
# Image-to-video
input_image: Image.Image = None,

View File

@@ -1,7 +1,7 @@
import importlib, inspect, pkgutil, traceback, torch, os, re
from typing import Union, List, Optional, Tuple, Iterable, Dict
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
from contextlib import contextmanager
from diffsynth.utils.data import VideoData
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
if image is not None: images.append(Image.open(image))
return images
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
    """Render a variable-length group of identical Streamlit widgets.

    A number input controls how many ``st_element`` widgets are drawn;
    each widget gets a unique ``key`` so Streamlit can track its state
    across reruns.

    Parameters:
        st_element: widget factory (e.g. ``st.text_input``) called once per element.
        name: label shown above the group and used to build widget keys.
        value: optional sequence of initial values, one per element.
        disabled: disable the count widget and all elements when True.
        kwargs: extra keyword arguments forwarded to every ``st_element`` call.

    Returns:
        list of the values produced by the rendered widgets.
    """
    if kwargs is None:
        kwargs = {}
    elements = []
    with st.container(border=True):
        st.markdown(name)
        num = st.number_input(
            f"num_{name}",
            min_value=0,
            max_value=20,
            value=0 if value is None else len(value),
            disabled=disabled,
        )
        for i in range(num):
            # Guard the index: the user can raise `num` above len(value),
            # which would otherwise raise IndexError on value[i].
            initial = value[i] if value is not None and i < len(value) else None
            element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=initial, **kwargs)
            elements.append(element)
    return elements
def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
if value is None:
with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
if enable:
return ui
else:
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
else:
return draw_ui_element_safely(name, dtype, value)
def draw_video(name, value=None, disabled=False):
    """Upload an .mp4 file and return its frames as a list.

    Returns None when no file has been uploaded.

    # NOTE(review): `value` is accepted for signature symmetry with the other
    # draw_* helpers but is never used — st.file_uploader cannot be pre-filled.
    """
    ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
    if ui is not None:
        # Wrap the uploaded buffer and materialize every frame eagerly.
        # Presumably VideoData indexing yields per-frame images — confirm
        # against diffsynth.utils.data.VideoData.
        ui = VideoData(ui)
        ui = [ui[i] for i in range(len(ui))]
    return ui
def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled)
ui = st.checkbox(name, value=value, disabled=disabled)
elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled)
elif dtype == list[ModelConfig]:
ui = draw_single_model_config(name, value=value, disabled=disabled)
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"]
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str:
if "prompt" in name:
ui = st.text_area(name, value, height=3, disabled=disabled)
ui = st.text_area(name, value=value, height=3, disabled=disabled)
else:
ui = st.text_input(name, value, disabled=disabled)
ui = st.text_input(name, value=value, disabled=disabled)
elif dtype == float:
ui = st.number_input(name, value, disabled=disabled)
ui = st.number_input(name, value=value, disabled=disabled)
elif dtype == int:
ui = st.number_input(name, value, step=1, disabled=disabled)
ui = st.number_input(name, value=value, step=1, disabled=disabled)
elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui)
elif dtype == List[Image.Image]:
ui = draw_multi_images(name, value, disabled=disabled)
elif dtype == List[ControlNetInput]:
ui = draw_controlnet_inputs(name, value, disabled=disabled)
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
if "video" in name:
ui = draw_video(name, value=value, disabled=disabled)
else:
ui = draw_multi_images(name, value=value, disabled=disabled)
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
elif dtype in [List[str], list[str]]:
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
elif dtype in [List[int], list[int]]:
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
elif dtype in [tuple[int, int], Tuple[int, int]]:
with st.container(border=True):
st.markdown(name)
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
elif isinstance(dtype, typing._LiteralGenericAlias):
with st.container(border=True):
st.markdown(f"{name} ({dtype})")
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
elif dtype is None:
if name == "progress_bar_cmd":
ui = value
@@ -260,7 +298,7 @@ def launch_webui():
with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"]
input_params = {}
params = parse_params(pipe.__call__)
params = parse_params(pipeline_class.__call__)
for param in params:
if param["name"] in ["self"]:
continue