mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-14 21:58:17 +00:00
update webui
This commit is contained in:
@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
# Image
|
||||
|
||||
@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
embedded_guidance: float = 4.0,
|
||||
|
||||
@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
embedded_guidance: float = 3.5,
|
||||
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
sigma_shift: float = None,
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# local prompts
|
||||
multidiffusion_prompts:tuple[str] =(),
|
||||
multidiffusion_masks:tuple[str]=(),
|
||||
multidiffusion_scales:tuple[str]=(),
|
||||
# Kontext
|
||||
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
||||
# ControlNet
|
||||
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
||||
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
|
||||
"kontext_images": kontext_images,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
||||
|
||||
@@ -169,7 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
denoising_strength: float = 1.0,
|
||||
# Image-to-video
|
||||
|
||||
@@ -115,7 +115,7 @@ class MovaAudioVideoPipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
# Image-to-video
|
||||
input_image: Image.Image = None,
|
||||
|
||||
@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
# Image
|
||||
|
||||
@@ -190,7 +190,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
# Image-to-video
|
||||
input_image: Image.Image = None,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import importlib, inspect, pkgutil, traceback, torch, os, re
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
|
||||
from contextlib import contextmanager
|
||||
|
||||
from diffsynth.utils.data import VideoData
|
||||
import streamlit as st
|
||||
from diffsynth import ModelConfig
|
||||
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
||||
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
|
||||
if image is not None: images.append(Image.open(image))
|
||||
return images
|
||||
|
||||
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
    """Draw a user-sized group of identical widgets and collect their results.

    A bordered container is rendered with ``name`` as its heading, followed by
    a number input that selects how many widgets to show, and then that many
    widgets created by ``st_element``.

    Args:
        st_element: Callable used to draw each widget. It is called as
            ``st_element(name, key=..., disabled=..., value=..., **kwargs)``,
            so it must accept those keyword arguments.
        name: Label passed to every widget; also used to build the per-widget
            ``key`` (``f"{name}_{i}"``) and the counter label.
        value: Optional sequence of default values, one per widget. Its length
            is the initial widget count.
        disabled: Forwarded to the counter and to every widget.
        kwargs: Optional extra keyword arguments forwarded to ``st_element``.

    Returns:
        A list with the return value of each ``st_element`` call, in order.
    """
    extra_kwargs = {} if kwargs is None else kwargs
    defaults = [] if value is None else list(value)
    elements = []
    with st.container(border=True):
        st.markdown(name)
        num = st.number_input(
            f"num_{name}",
            min_value=0,
            # Keep the upper bound at least as large as the default count so a
            # long default list is not rejected by the widget.
            max_value=max(20, len(defaults)),
            value=len(defaults),
            disabled=disabled,
        )
        for i in range(num):
            # The user may request more widgets than there are defaults;
            # the extra widgets start without a default instead of indexing
            # past the end of the sequence.
            default = defaults[i] if i < len(defaults) else None
            element = st_element(
                name,
                key=f"{name}_{i}",
                disabled=disabled,
                value=default,
                **extra_kwargs,
            )
            elements.append(element)
    return elements
|
||||
|
||||
def draw_controlnet_input(name="", value=None, disabled=False):
|
||||
with st.container(border=True):
|
||||
st.markdown(name)
|
||||
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
|
||||
if value is None:
|
||||
with st.container(border=True):
|
||||
enable = st.checkbox(f"Enable {name}", value=False)
|
||||
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
|
||||
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
|
||||
if enable:
|
||||
return ui
|
||||
else:
|
||||
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
|
||||
else:
|
||||
return draw_ui_element_safely(name, dtype, value)
|
||||
|
||||
def draw_video(name, value=None, disabled=False, key=None):
    """Draw an MP4 uploader and return the uploaded video as a list of frames.

    Args:
        name: Label of the file uploader.
        value: Accepted for signature parity with the other ``draw_*`` helpers;
            it is not used, since the uploader takes no default.
        disabled: Whether the uploader is disabled.
        key: Optional widget key forwarded to ``st.file_uploader``. Needed
            when several uploaders share the same label, e.g. when this
            function is used as the element factory of a multi-element group
            that passes ``key=...`` to every element.

    Returns:
        ``None`` when nothing has been uploaded; otherwise a list built by
        indexing the ``VideoData`` wrapper from ``0`` to ``len - 1``.
    """
    ui = st.file_uploader(name, type=["mp4"], disabled=disabled, key=key)
    if ui is not None:
        ui = VideoData(ui)
        # Index explicitly: only __len__ and __getitem__ are relied upon here.
        ui = [ui[i] for i in range(len(ui))]
    return ui
|
||||
|
||||
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
if dtype == torch.dtype:
|
||||
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
||||
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
||||
elif dtype == bool:
|
||||
ui = st.checkbox(name, value, disabled=disabled)
|
||||
ui = st.checkbox(name, value=value, disabled=disabled)
|
||||
elif dtype == ModelConfig:
|
||||
ui = draw_single_model_config(name, value, disabled=disabled)
|
||||
elif dtype == list[ModelConfig]:
|
||||
ui = draw_single_model_config(name, value=value, disabled=disabled)
|
||||
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
|
||||
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
||||
model_configs = st.session_state["model_configs_from_example"]
|
||||
del st.session_state["model_configs_from_example"]
|
||||
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||
ui = draw_multi_model_config(name, disabled=disabled)
|
||||
elif dtype == str:
|
||||
if "prompt" in name:
|
||||
ui = st.text_area(name, value, height=3, disabled=disabled)
|
||||
ui = st.text_area(name, value=value, height=3, disabled=disabled)
|
||||
else:
|
||||
ui = st.text_input(name, value, disabled=disabled)
|
||||
ui = st.text_input(name, value=value, disabled=disabled)
|
||||
elif dtype == float:
|
||||
ui = st.number_input(name, value, disabled=disabled)
|
||||
ui = st.number_input(name, value=value, disabled=disabled)
|
||||
elif dtype == int:
|
||||
ui = st.number_input(name, value, step=1, disabled=disabled)
|
||||
ui = st.number_input(name, value=value, step=1, disabled=disabled)
|
||||
elif dtype == Image.Image:
|
||||
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
||||
if ui is not None: ui = Image.open(ui)
|
||||
elif dtype == List[Image.Image]:
|
||||
ui = draw_multi_images(name, value, disabled=disabled)
|
||||
elif dtype == List[ControlNetInput]:
|
||||
ui = draw_controlnet_inputs(name, value, disabled=disabled)
|
||||
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
|
||||
if "video" in name:
|
||||
ui = draw_video(name, value=value, disabled=disabled)
|
||||
else:
|
||||
ui = draw_multi_images(name, value=value, disabled=disabled)
|
||||
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
|
||||
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
|
||||
elif dtype in [List[str], list[str]]:
|
||||
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
|
||||
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
|
||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
|
||||
elif dtype in [List[int], list[int]]:
|
||||
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
|
||||
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
|
||||
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
|
||||
elif dtype in [tuple[int, int], Tuple[int, int]]:
|
||||
with st.container(border=True):
|
||||
st.markdown(name)
|
||||
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
|
||||
elif isinstance(dtype, typing._LiteralGenericAlias):
|
||||
with st.container(border=True):
|
||||
st.markdown(f"{name} ({dtype})")
|
||||
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
|
||||
elif dtype is None:
|
||||
if name == "progress_bar_cmd":
|
||||
ui = value
|
||||
@@ -260,7 +298,7 @@ def launch_webui():
|
||||
with st.expander("Input", expanded=True):
|
||||
pipe = st.session_state["pipe"]
|
||||
input_params = {}
|
||||
params = parse_params(pipe.__call__)
|
||||
params = parse_params(pipeline_class.__call__)
|
||||
for param in params:
|
||||
if param["name"] in ["self"]:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user