Compare commits

...

3 Commits

Author SHA1 Message Date
Artiprocher
e2a04139b4 update webui 2026-04-14 15:50:48 +08:00
lzws
db0f1571b1 fix pipeline args 2026-04-14 15:17:19 +08:00
Artiprocher
224060c2a0 add webui 2026-04-13 10:55:51 +08:00
11 changed files with 434 additions and 118 deletions

View File

@@ -1,4 +1,4 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
value_bias = False
)
super().__init__(config)
self.processor = DINOv3ViTImageProcessorFast(
self.processor = DINOv3ViTImageProcessor(
crop_size = None,
data_format = "channels_first",
default_to_square = True,
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
0.456,
0.406
],
image_processor_type = "DINOv3ViTImageProcessorFast",
image_processor_type = "DINOv3ViTImageProcessor",
image_std = [
0.229,
0.224,

View File

@@ -1,5 +1,5 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor
import torch
from diffsynth.core.device.npu_compatible_device import get_device_type
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
transformers_version = "4.57.1"
)
super().__init__(config)
self.processor = Siglip2ImageProcessorFast(
self.processor = Siglip2ImageProcessor(
**{
"data_format": "channels_first",
"default_to_square": True,
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
0.5,
0.5
],
"image_processor_type": "Siglip2ImageProcessorFast",
"image_processor_type": "Siglip2ImageProcessor",
"image_std": [
0.5,
0.5,

View File

@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 4.0,
@@ -83,7 +83,7 @@ class Flux2ImagePipeline(BasePipeline):
input_image: Image.Image = None,
denoising_strength: float = 1.0,
# Edit
edit_image: Union[Image.Image, List[Image.Image]] = None,
edit_image: List[Image.Image] = None,
edit_image_auto_resize: bool = True,
# Shape
height: int = 1024,

View File

@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 1.0,
embedded_guidance: float = 3.5,
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
sigma_shift: float = None,
# Steps
num_inference_steps: int = 30,
# local prompts
multidiffusion_prompts=(),
multidiffusion_masks=(),
multidiffusion_scales=(),
# Kontext
kontext_images: Union[list[Image.Image], Image.Image] = None,
# ControlNet
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,

View File

@@ -169,46 +169,46 @@ class LTX2AudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
prompt: str = "",
negative_prompt: str = "",
denoising_strength: float = 1.0,
# Image-to-video
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = [0],
input_images_strength: Optional[float] = 1.0,
input_images: list[Image.Image] = None,
input_images_indexes: list[int] = [0],
input_images_strength: float = 1.0,
# In-Context Video Control
in_context_videos: Optional[list[list[Image.Image]]] = None,
in_context_downsample_factor: Optional[int] = 2,
in_context_videos: list[list[Image.Image]] = None,
in_context_downsample_factor: int = 2,
# Video-to-video
retake_video: Optional[list[Image.Image]] = None,
retake_video_regions: Optional[list[tuple[float, float]]] = None,
retake_video: list[Image.Image] = None,
retake_video_regions: list[tuple[float, float]] = None,
# Audio-to-video
retake_audio: Optional[torch.Tensor] = None,
audio_sample_rate: Optional[int] = 48000,
retake_audio_regions: Optional[list[tuple[float, float]]] = None,
retake_audio: torch.Tensor = None,
audio_sample_rate: int = 48000,
retake_audio_regions: list[tuple[float, float]] = None,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
seed: int = None,
rand_device: str = "cpu",
# Shape
height: Optional[int] = 512,
width: Optional[int] = 768,
num_frames: Optional[int] = 121,
frame_rate: Optional[int] = 24,
height: int = 512,
width: int = 768,
num_frames: int = 121,
frame_rate: int = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
cfg_scale: float = 3.0,
# Scheduler
num_inference_steps: Optional[int] = 30,
num_inference_steps: int = 30,
# VAE tiling
tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512,
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
tiled: bool = True,
tile_size_in_pixels: int = 512,
tile_overlap_in_pixels: int = 128,
tile_size_in_frames: int = 128,
tile_overlap_in_frames: int = 24,
# Special Pipelines
use_two_stage_pipeline: Optional[bool] = False,
stage2_spatial_upsample_factor: Optional[int] = 2,
clear_lora_before_state_two: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False,
use_two_stage_pipeline: bool = False,
stage2_spatial_upsample_factor: int = 2,
clear_lora_before_state_two: bool = False,
use_distilled_pipeline: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
):

View File

@@ -115,33 +115,33 @@ class MovaAudioVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
prompt: str = "",
negative_prompt: str = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
input_image: Image.Image = None,
# First-last-frame-to-video
end_image: Optional[Image.Image] = None,
end_image: Image.Image = None,
# Video-to-video
denoising_strength: Optional[float] = 1.0,
denoising_strength: float = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
seed: int = None,
rand_device: str = "cpu",
# Shape
height: Optional[int] = 352,
width: Optional[int] = 640,
num_frames: Optional[int] = 81,
frame_rate: Optional[int] = 24,
height: int = 352,
width: int = 640,
num_frames: int = 81,
frame_rate: int = 24,
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
cfg_scale: float = 5.0,
# Boundary
switch_DiT_boundary: Optional[float] = 0.9,
switch_DiT_boundary: float = 0.9,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
num_inference_steps: int = 50,
sigma_shift: float = 5.0,
# VAE tiling
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
tiled: bool = True,
tile_size: tuple[int, int] = (30, 52),
tile_stride: tuple[int, int] = (15, 26),
# progress_bar
progress_bar_cmd=tqdm,
):

View File

@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 4.0,
# Image

View File

@@ -190,82 +190,82 @@ class WanVideoPipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: Optional[str] = "",
prompt: str = "",
negative_prompt: str = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
input_image: Image.Image = None,
# First-last-frame-to-video
end_image: Optional[Image.Image] = None,
end_image: Image.Image = None,
# Video-to-video
input_video: Optional[list[Image.Image]] = None,
denoising_strength: Optional[float] = 1.0,
input_video: list[Image.Image] = None,
denoising_strength: float = 1.0,
# Speech-to-video
input_audio: Optional[np.array] = None,
audio_embeds: Optional[torch.Tensor] = None,
audio_sample_rate: Optional[int] = 16000,
s2v_pose_video: Optional[list[Image.Image]] = None,
s2v_pose_latents: Optional[torch.Tensor] = None,
motion_video: Optional[list[Image.Image]] = None,
input_audio: np.array = None,
audio_embeds: torch.Tensor = None,
audio_sample_rate: int = 16000,
s2v_pose_video: list[Image.Image] = None,
s2v_pose_latents: torch.Tensor = None,
motion_video: list[Image.Image] = None,
# ControlNet
control_video: Optional[list[Image.Image]] = None,
reference_image: Optional[Image.Image] = None,
control_video: list[Image.Image] = None,
reference_image: Image.Image = None,
# Camera control
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
camera_control_speed: Optional[float] = 1/54,
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
camera_control_direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"] = None,
camera_control_speed: float = 1/54,
camera_control_origin: tuple = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
# VACE
vace_video: Optional[list[Image.Image]] = None,
vace_video_mask: Optional[Image.Image] = None,
vace_reference_image: Optional[Image.Image] = None,
vace_scale: Optional[float] = 1.0,
vace_video: list[Image.Image] = None,
vace_video_mask: Image.Image = None,
vace_reference_image: Image.Image = None,
vace_scale: float = 1.0,
# Animate
animate_pose_video: Optional[list[Image.Image]] = None,
animate_face_video: Optional[list[Image.Image]] = None,
animate_inpaint_video: Optional[list[Image.Image]] = None,
animate_mask_video: Optional[list[Image.Image]] = None,
animate_pose_video: list[Image.Image] = None,
animate_face_video: list[Image.Image] = None,
animate_inpaint_video: list[Image.Image] = None,
animate_mask_video: list[Image.Image] = None,
# VAP
vap_video: Optional[list[Image.Image]] = None,
vap_prompt: Optional[str] = " ",
negative_vap_prompt: Optional[str] = " ",
vap_video: list[Image.Image] = None,
vap_prompt: str = " ",
negative_vap_prompt: str = " ",
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
seed: int = None,
rand_device: str = "cpu",
# Shape
height: Optional[int] = 480,
width: Optional[int] = 832,
num_frames=81,
height: int = 480,
width: int = 832,
num_frames: int = 81,
# Classifier-free guidance
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
cfg_scale: float = 5.0,
cfg_merge: bool = False,
# Boundary
switch_DiT_boundary: Optional[float] = 0.875,
switch_DiT_boundary: float = 0.875,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
num_inference_steps: int = 50,
sigma_shift: float = 5.0,
# Speed control
motion_bucket_id: Optional[int] = None,
motion_bucket_id: int = None,
# LongCat-Video
longcat_video: Optional[list[Image.Image]] = None,
longcat_video: list[Image.Image] = None,
# VAE tiling
tiled: Optional[bool] = True,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
tiled: bool = True,
tile_size: tuple[int, int] = (30, 52),
tile_stride: tuple[int, int] = (15, 26),
# Sliding window
sliding_window_size: Optional[int] = None,
sliding_window_stride: Optional[int] = None,
sliding_window_size: int = None,
sliding_window_stride: int = None,
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
tea_cache_l1_thresh: float = None,
tea_cache_model_id: str = "",
# WanToDance
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
wantodance_music_path: str = None,
wantodance_reference_image: Image.Image = None,
wantodance_fps: float = 30,
wantodance_keyframes: list[Image.Image] = None,
wantodance_keyframes_mask: list[int] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
output_type: Literal["quantized", "floatpoint"] = "quantized",
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

View File

@@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline):
def __call__(
self,
# Prompt
prompt: str,
prompt: str = "",
negative_prompt: str = "",
cfg_scale: float = 1.0,
# Image
@@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline):
width: int = 1024,
# Randomness
seed: int = None,
rand_device: str = "cpu",
rand_device: Union[str, torch.device] = "cpu",
# Steps
num_inference_steps: int = 8,
sigma_shift: float = None,

321
examples/dev_tools/webui.py Normal file
View File

@@ -0,0 +1,321 @@
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
from contextlib import contextmanager
from diffsynth.utils.data import VideoData
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
from PIL import Image
from tqdm import tqdm
# Use the full browser width for the two-column layout below.
st.set_page_config(layout="wide")
class StreamlitTqdmWrapper:
    """Progress adapter driving a console tqdm bar and, optionally, a
    Streamlit progress bar from a single iteration.

    Attributes:
        iterable: The wrapped iterable.
        st_progress_bar: Streamlit progress widget, or None to skip UI updates.
        tqdm_bar: The underlying tqdm instance.
        total: Length of the iterable when known, else None.
        current: Number of items yielded so far.
    """

    def __init__(self, iterable, st_progress_bar=None):
        self.iterable = iterable
        self.st_progress_bar = st_progress_bar
        self.tqdm_bar = tqdm(iterable)
        # Unsized iterables (generators) cannot drive a fractional progress bar.
        self.total = len(iterable) if hasattr(iterable, '__len__') else None
        self.current = 0

    def __iter__(self):
        push_to_streamlit = self.st_progress_bar is not None and self.total is not None
        for item in self.tqdm_bar:
            if push_to_streamlit:
                self.current += 1
                self.st_progress_bar.progress(self.current / self.total)
            yield item

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Delegate cleanup to tqdm when it supports the context protocol.
        exit_hook = getattr(self.tqdm_bar, '__exit__', None)
        if exit_hook is not None:
            exit_hook(*exc_info)
@contextmanager
def catch_error(error_value):
    """Context manager that swallows any exception raised in its body.

    The traceback is printed to stdout together with *error_value* so the
    WebUI keeps running when, e.g., an optional pipeline fails to import.
    """
    try:
        yield
    except Exception:
        print(f"Error {error_value}:\n{traceback.format_exc()}")
def parse_model_configs_from_an_example(path):
    """Scan an example script and rebuild the ModelConfig objects it declares.

    Only lines starting with ``ModelConfig`` are considered; string keyword
    arguments are extracted with a regex (no code is executed).

    Args:
        path: Path to an example ``.py`` file.

    Returns:
        list[ModelConfig]: One entry per ``ModelConfig(...)`` line found.
    """
    model_configs = []
    with open(path, "r") as f:
        for code in f.readlines():
            code = code.strip()
            if not code.startswith("ModelConfig"):
                continue
            # Pull every `key="value"` / `key='value'` pair out of the call.
            pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
            config_dict = dict(pairs)
            # Use .get so a line missing one of the keys (e.g. no
            # origin_file_pattern) does not raise KeyError.
            model_configs.append(ModelConfig(
                model_id=config_dict.get("model_id"),
                origin_file_pattern=config_dict.get("origin_file_pattern"),
            ))
    return model_configs
def list_examples(path, keyword=None):
    """Recursively collect example ``.py`` files under *path*.

    Args:
        path: File or directory to search.
        keyword: When given, only scripts whose source contains this
            substring are returned.

    Returns:
        list[str]: Matching script paths.
    """
    if os.path.isdir(path):
        found = []
        for entry in os.listdir(path):
            found.extend(list_examples(os.path.join(path, entry), keyword=keyword))
        return found
    if not path.endswith(".py"):
        return []
    with open(path, "r") as f:
        source = f.read()
    return [path] if keyword is None or keyword in source else []
def parse_available_pipelines():
    """Discover every concrete pipeline class shipped in diffsynth.pipelines.

    Modules that fail to import are skipped; catch_error logs the traceback.

    Returns:
        dict[str, type]: Pipeline class name -> pipeline class.
    """
    from diffsynth.diffusion.base_pipeline import BasePipeline
    import diffsynth.pipelines as _pipelines_pkg
    available_pipelines = {}
    for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
        with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
            mod = importlib.import_module(f"diffsynth.pipelines.{name}")
            for cls_name, cls in inspect.getmembers(mod, inspect.isclass):
                # Keep only pipelines defined in this module itself, and
                # exclude the abstract base class.
                if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__:
                    available_pipelines[cls_name] = cls
    return available_pipelines
def parse_available_examples(path, available_pipelines):
    """Map each pipeline name to the example scripts that instantiate it.

    The sentinel "None" is prepended so the selectbox has an opt-out entry.
    """
    return {
        pipeline_name: ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
        for pipeline_name in available_pipelines
    }
def draw_selectbox(label, options, option_map, value=None, disabled=False):
    """Render a Streamlit selectbox whose display strings map to Python values.

    Args:
        label: Widget label.
        options: Iterable of display strings (typically option_map keys).
        option_map: Mapping from display string to underlying value.
        value: Underlying value to preselect; falls back to the first option
            when None or when not present in option_map (the original code
            raised IndexError in that case).
        disabled: Whether the widget is greyed out.

    Returns:
        The underlying value for the chosen option (None when the chosen
        display string is missing from option_map).
    """
    options = tuple(options)
    if value is None:
        default_index = 0
    else:
        # Fall back to index 0 instead of crashing when value is unknown.
        default_index = next(
            (i for i, option in enumerate(options) if option_map.get(option) == value),
            0,
        )
    option = st.selectbox(label=label, options=options, index=default_index, disabled=disabled)
    return option_map.get(option)
def parse_params(fn):
    """Introspect *fn*'s signature into a list of UI parameter descriptors.

    Returns:
        list[dict]: One dict per parameter with keys "name", "dtype"
        (annotation or None) and "value" (default or None).
    """
    empty = inspect.Parameter.empty
    descriptors = []
    for name, param in inspect.signature(fn).parameters.items():
        descriptors.append({
            "name": name,
            "dtype": None if param.annotation is empty else param.annotation,
            "value": None if param.default is empty else param.default,
        })
    return descriptors
def draw_model_config(model_config=None, key_suffix="", disabled=False):
    """Render editable fields for one ModelConfig and return the edited copy.

    Args:
        model_config: Existing config used to pre-fill the fields; a blank
            ModelConfig is used when None.
        key_suffix: Suffix appended to widget keys so repeated editors do
            not collide in Streamlit's widget state.
        disabled: Whether all fields are greyed out.

    Returns:
        ModelConfig built from the current field values; an empty path
        field becomes None.
    """
    with st.container(border=True):
        if model_config is None:
            model_config = ModelConfig()
        path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
        col1, col2 = st.columns(2)
        with col1:
            model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
        with col2:
            origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
        edited = ModelConfig(
            path=None if path == "" else path,
            model_id=model_id,
            origin_file_pattern=origin_file_pattern,
        )
    return edited
def draw_multi_model_config(name="", value=None, disabled=False):
    """Render a variable-length list of ModelConfig editors.

    The user picks how many entries to show (0-20); *value* pre-fills
    both the count and the individual editors.
    """
    with st.container(border=True):
        st.markdown(name)
        count = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
        return [
            draw_model_config(
                key_suffix=f"_{name}_{i}",
                model_config=None if value is None else value[i],
                disabled=disabled,
            )
            for i in range(count)
        ]
def draw_single_model_config(name="", value=None, disabled=False):
    """Render a titled editor for exactly one ModelConfig."""
    with st.container(border=True):
        st.markdown(name)
        return draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
def draw_multi_images(name="", value=None, disabled=False):
    """Render a variable-length list of image uploaders.

    Returns:
        list[Image.Image]: Opened PIL images; empty uploaders are skipped.
    """
    with st.container(border=True):
        st.markdown(name)
        count = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
        uploads = [
            st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
            for i in range(count)
        ]
        return [Image.open(upload) for upload in uploads if upload is not None]
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
    """Render *st_element* repeatedly to collect a variable-length list.

    Args:
        st_element: Streamlit widget factory (e.g. st.text_input).
        name: Label and key prefix for the widgets.
        value: Pre-fills both the element count and the individual widgets.
        disabled: Whether the widgets are greyed out.
        kwargs: Extra keyword arguments forwarded to every widget call.
    """
    kwargs = {} if kwargs is None else kwargs
    with st.container(border=True):
        st.markdown(name)
        count = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
        return [
            st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
            for i in range(count)
        ]
def draw_controlnet_input(name="", value=None, disabled=False):
    """Render the widgets for one ControlNetInput and return it.

    *value* is accepted for signature parity with the other draw_* helpers
    but is not used to pre-fill the widgets.
    """
    def _upload_image(label):
        # Image uploader returning a PIL image, or None when nothing is uploaded.
        upload = st.file_uploader(label, type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_{label}")
        return None if upload is None else Image.open(upload)

    with st.container(border=True):
        st.markdown(name)
        controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
        scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
        image = _upload_image("image")
        inpaint_image = _upload_image("inpaint_image")
        inpaint_mask = _upload_image("inpaint_mask")
    return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
def draw_controlnet_inputs(name, value=None, disabled=False):
    """Render a variable-length list of ControlNetInput editors."""
    with st.container(border=True):
        st.markdown(name)
        count = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
        return [
            draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
            for i in range(count)
        ]
def draw_ui_element(name, dtype, value):
    """Render the appropriate widget for one pipeline parameter.

    Tensor-typed parameters cannot be edited in the WebUI and are skipped
    (returns None). Parameters whose default is None are gated behind an
    "Enable" checkbox so the user explicitly opts in; while the checkbox
    is off, the widget is drawn disabled and None is returned.
    """
    if dtype in (Dict[str, torch.Tensor], torch.Tensor):
        return None
    if value is not None:
        return draw_ui_element_safely(name, dtype, value)
    with st.container(border=True):
        enable = st.checkbox(f"Enable {name}", value=False)
        ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
        return ui if enable else None
def draw_video(name, value=None, disabled=False):
    """Render an mp4 uploader and return the video as a list of frames.

    Returns None when nothing is uploaded. *value* is accepted for
    signature parity with the other draw_* helpers but is not used.
    """
    upload = st.file_uploader(name, type=["mp4"], disabled=disabled)
    if upload is None:
        return None
    video = VideoData(upload)
    return [video[i] for i in range(len(video))]
def draw_ui_element_safely(name, dtype, value, disabled=False):
    """Render the widget matching *dtype* and return the collected value.

    Args:
        name: Parameter name, used as the widget label and key prefix.
        dtype: The parameter's type annotation (None when unannotated).
        value: Default value used to pre-fill the widget.
        disabled: Whether the widget is greyed out.

    Returns:
        The value entered by the user, converted to the annotated type
        where possible; falls back to *value* for unsupported dtypes.
    """
    if dtype == torch.dtype:
        option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
        ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
    elif dtype == Union[str, torch.device]:
        option_map = {"cuda": "cuda", "cpu": "cpu"}
        ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
    elif dtype == bool:
        ui = st.checkbox(name, value=value, disabled=disabled)
    elif dtype == ModelConfig:
        ui = draw_single_model_config(name, value=value, disabled=disabled)
    elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
        if name == "model_configs" and "model_configs_from_example" in st.session_state:
            # One-shot hand-off from the example parser in launch_webui.
            model_configs = st.session_state["model_configs_from_example"]
            del st.session_state["model_configs_from_example"]
            ui = draw_multi_model_config(name, model_configs, disabled=disabled)
        else:
            ui = draw_multi_model_config(name, disabled=disabled)
    elif dtype == str:
        if "prompt" in name:
            ui = st.text_area(name, value=value, height=3, disabled=disabled)
        else:
            ui = st.text_input(name, value=value, disabled=disabled)
    elif dtype == float:
        ui = st.number_input(name, value=value, disabled=disabled)
    elif dtype == int:
        ui = st.number_input(name, value=value, step=1, disabled=disabled)
    elif dtype == Image.Image:
        ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
        if ui is not None: ui = Image.open(ui)
    elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
        # Heuristic: image-list parameters named *video* are uploaded as mp4.
        if "video" in name:
            ui = draw_video(name, value=value, disabled=disabled)
        else:
            ui = draw_multi_images(name, value=value, disabled=disabled)
    elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
        ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
    elif dtype in [List[str], list[str]]:
        ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
    elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
    elif dtype in [List[int], list[int]]:
        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
    elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
        ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
    elif dtype in [tuple[int, int], Tuple[int, int]]:
        with st.container(border=True):
            st.markdown(name)
            # st.text_input returns strings; cast back to int so callers
            # annotated tuple[int, int] (e.g. tile_size / tile_stride)
            # receive real integers instead of string pairs.
            ui = (
                int(st.text_input(f"{name}_0", value=value[0], disabled=disabled)),
                int(st.text_input(f"{name}_1", value=value[1], disabled=disabled)),
            )
    elif isinstance(dtype, typing._LiteralGenericAlias):
        # NOTE(review): typing._LiteralGenericAlias is a private API —
        # confirm it exists on all supported Python versions.
        with st.container(border=True):
            st.markdown(f"{name} ({dtype})")
            ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
    else:
        # Covers dtype is None (unannotated) and any unsupported annotation
        # such as np.array. Previously unsupported annotations matched no
        # branch and `return ui` raised UnboundLocalError; also fixes the
        # "is not not configurable" typo in the message.
        if name != "progress_bar_cmd":
            st.markdown(f"(`{name}` is not configurable in WebUI). dtype: `{dtype}`.")
        ui = value
    return ui
def launch_webui():
    """Render the three-step DiffSynth WebUI: pick a pipeline class, load
    its models, fill in the call parameters, then generate.

    Streamlit re-executes this function on every interaction; progress
    through the steps is persisted in st.session_state.
    """
    input_col, output_col = st.columns(2)
    with input_col:
        # Cache the expensive pipeline/example discovery across reruns.
        if "available_pipelines" not in st.session_state:
            st.session_state["available_pipelines"] = parse_available_pipelines()
        if "available_examples" not in st.session_state:
            st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
        with st.expander("Pipeline", expanded=True):
            pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
            example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
            if example != "None":
                # Hand the parsed configs to draw_ui_element_safely via session state.
                st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
            if st.button("Step 1: Parse Pipeline", type="primary"):
                st.session_state["pipeline_class"] = pipeline_class
        # Later steps only render once the previous step has completed.
        if "pipeline_class" not in st.session_state:
            return
        with st.expander("Model", expanded=True):
            input_params = {}
            params = parse_params(pipeline_class.from_pretrained)
            for param in params:
                input_params[param["name"]] = draw_ui_element(**param)
            if st.button("Step 2: Load Models", type="primary"):
                with st.spinner("Loading models", show_time=True):
                    # Drop any previously loaded pipeline before loading a new one.
                    if "pipe" in st.session_state:
                        del st.session_state["pipe"]
                        torch.cuda.empty_cache()
                    st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
        if "pipe" not in st.session_state:
            return
        with st.expander("Input", expanded=True):
            pipe = st.session_state["pipe"]
            input_params = {}
            params = parse_params(pipeline_class.__call__)
            for param in params:
                if param["name"] in ["self"]:
                    continue
                input_params[param["name"]] = draw_ui_element(**param)
    with output_col:
        if st.button("Step 3: Generate", type="primary"):
            if "progress_bar_cmd" in input_params:
                # Mirror tqdm progress into a Streamlit progress bar.
                input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
            result = pipe(**input_params)
            st.session_state["result"] = result
        if "result" in st.session_state:
            result = st.session_state["result"]
            if isinstance(result, Image.Image):
                st.image(result)
            else:
                # NOTE(review): non-image results (video/audio) are not displayed yet.
                print(f"unsupported result format: {result}")
# Entry point: Streamlit executes the script top-to-bottom on every rerun.
launch_webui()