From b1af4af8a90c1bee472d7358f58e86bb039930a5 Mon Sep 17 00:00:00 2001 From: Zhongjie Duan <35051019+Artiprocher@users.noreply.github.com> Date: Fri, 24 Apr 2026 10:13:36 +0800 Subject: [PATCH] support inference webui tools (#1409) support inference webui tools --- diffsynth/models/dinov3_image_encoder.py | 6 +- diffsynth/models/siglip2_image_encoder.py | 6 +- diffsynth/pipelines/anima_image.py | 2 +- diffsynth/pipelines/flux2_image.py | 4 +- diffsynth/pipelines/flux_image.py | 7 +- diffsynth/pipelines/ltx2_audio_video.py | 58 ++-- diffsynth/pipelines/mova_audio_video.py | 36 +-- diffsynth/pipelines/qwen_image.py | 2 +- diffsynth/pipelines/wan_video.py | 106 +++---- diffsynth/pipelines/z_image.py | 4 +- examples/dev_tools/webui.py | 331 ++++++++++++++++++++++ 11 files changed, 444 insertions(+), 118 deletions(-) create mode 100644 examples/dev_tools/webui.py diff --git a/diffsynth/models/dinov3_image_encoder.py b/diffsynth/models/dinov3_image_encoder.py index c394a03..052f856 100644 --- a/diffsynth/models/dinov3_image_encoder.py +++ b/diffsynth/models/dinov3_image_encoder.py @@ -1,4 +1,4 @@ -from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast +from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig import torch @@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel): value_bias = False ) super().__init__(config) - self.processor = DINOv3ViTImageProcessorFast( + self.processor = DINOv3ViTImageProcessor( crop_size = None, data_format = "channels_first", default_to_square = True, @@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel): 0.456, 0.406 ], - image_processor_type = "DINOv3ViTImageProcessorFast", + image_processor_type = "DINOv3ViTImageProcessor", image_std = [ 0.229, 0.224, diff --git a/diffsynth/models/siglip2_image_encoder.py b/diffsynth/models/siglip2_image_encoder.py index 509eff4..58e1d15 100644 --- a/diffsynth/models/siglip2_image_encoder.py +++ b/diffsynth/models/siglip2_image_encoder.py @@ -1,5 +1,5 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig -from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast +from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor import torch from diffsynth.core.device.npu_compatible_device import get_device_type @@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel): transformers_version = "4.57.1" ) super().__init__(config) - self.processor = Siglip2ImageProcessorFast( + self.processor = Siglip2ImageProcessor( **{ "data_format": "channels_first", "default_to_square": True, @@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel): 0.5, 0.5 ], - "image_processor_type": "Siglip2ImageProcessorFast", + "image_processor_type": "Siglip2ImageProcessor", "image_std": [ 0.5, 0.5, diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py index bc4f6cd..e29b6f9 100644 --- a/diffsynth/pipelines/anima_image.py +++ b/diffsynth/pipelines/anima_image.py @@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, + prompt: str = "", negative_prompt: str = "", cfg_scale: float = 4.0, # Image diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py index 7b6dcc4..fae29b6 100644 --- a/diffsynth/pipelines/flux2_image.py +++ b/diffsynth/pipelines/flux2_image.py @@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, + prompt: str = "", negative_prompt: str = "", cfg_scale: float = 1.0, embedded_guidance: float = 4.0, @@ -83,7 +83,7 @@ class Flux2ImagePipeline(BasePipeline): input_image: Image.Image = None, denoising_strength: float = 1.0, # Edit - edit_image: Union[Image.Image, List[Image.Image]] = None, + edit_image: List[Image.Image] = None, edit_image_auto_resize: bool = True, # Shape height: int = 1024, diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index db2d522..705012a 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, + prompt: str = "", negative_prompt: str = "", cfg_scale: float = 1.0, embedded_guidance: float = 3.5, @@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline): sigma_shift: float = None, # Steps num_inference_steps: int = 30, - # local prompts - multidiffusion_prompts=(), - multidiffusion_masks=(), - multidiffusion_scales=(), # Kontext kontext_images: Union[list[Image.Image], Image.Image] = None, # ControlNet @@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline): "height": height, "width": width, "seed": seed, "rand_device": rand_device, "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, - "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales, "kontext_images": kontext_images, "controlnet_inputs": controlnet_inputs, "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 1263b43..10723e3 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -169,46 +169,46 @@ class LTX2AudioVideoPipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, - negative_prompt: Optional[str] = "", + prompt: str = "", + negative_prompt: str = "", denoising_strength: float = 1.0, # Image-to-video - input_images: Optional[list[Image.Image]] = None, - input_images_indexes: Optional[list[int]] = [0], - input_images_strength: Optional[float] = 1.0, + input_images: list[Image.Image] = None, + input_images_indexes: list[int] = [0], + input_images_strength: float = 1.0, # In-Context Video Control - in_context_videos: Optional[list[list[Image.Image]]] = None, - in_context_downsample_factor: Optional[int] = 2, + in_context_videos: list[list[Image.Image]] = None, + in_context_downsample_factor: int = 2, # Video-to-video - retake_video: Optional[list[Image.Image]] = None, - retake_video_regions: Optional[list[tuple[float, float]]] = None, + retake_video: list[Image.Image] = None, + retake_video_regions: list[tuple[float, float]] = None, # Audio-to-video - retake_audio: Optional[torch.Tensor] = None, - audio_sample_rate: Optional[int] = 48000, - retake_audio_regions: Optional[list[tuple[float, float]]] = None, + retake_audio: torch.Tensor = None, + audio_sample_rate: int = 48000, + retake_audio_regions: list[tuple[float, float]] = None, # Randomness - seed: Optional[int] = None, - rand_device: Optional[str] = "cpu", + seed: int = None, + rand_device: str = "cpu", # Shape - height: Optional[int] = 512, - width: Optional[int] = 768, - num_frames: Optional[int] = 121, - frame_rate: Optional[int] = 24, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: int = 24, # Classifier-free guidance - cfg_scale: Optional[float] = 3.0, + cfg_scale: float = 3.0, # Scheduler - num_inference_steps: Optional[int] = 30, + num_inference_steps: int = 30, # VAE tiling - tiled: Optional[bool] = True, - tile_size_in_pixels: Optional[int] = 512, - tile_overlap_in_pixels: Optional[int] = 128, - tile_size_in_frames: Optional[int] = 128, - tile_overlap_in_frames: Optional[int] = 24, + tiled: bool = True, + tile_size_in_pixels: int = 512, + tile_overlap_in_pixels: int = 128, + tile_size_in_frames: int = 128, + tile_overlap_in_frames: int = 24, # Special Pipelines - use_two_stage_pipeline: Optional[bool] = False, - stage2_spatial_upsample_factor: Optional[int] = 2, - clear_lora_before_state_two: Optional[bool] = False, - use_distilled_pipeline: Optional[bool] = False, + use_two_stage_pipeline: bool = False, + stage2_spatial_upsample_factor: int = 2, + clear_lora_before_state_two: bool = False, + use_distilled_pipeline: bool = False, # progress_bar progress_bar_cmd=tqdm, ): diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py index d89d3ff..aa83440 100644 --- a/diffsynth/pipelines/mova_audio_video.py +++ b/diffsynth/pipelines/mova_audio_video.py @@ -115,33 +115,33 @@ class MovaAudioVideoPipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, - negative_prompt: Optional[str] = "", + prompt: str = "", + negative_prompt: str = "", # Image-to-video - input_image: Optional[Image.Image] = None, + input_image: Image.Image = None, # First-last-frame-to-video - end_image: Optional[Image.Image] = None, + end_image: Image.Image = None, # Video-to-video - denoising_strength: Optional[float] = 1.0, + denoising_strength: float = 1.0, # Randomness - seed: Optional[int] = None, - rand_device: Optional[str] = "cpu", + seed: int = None, + rand_device: str = "cpu", # Shape - height: Optional[int] = 352, - width: Optional[int] = 640, - num_frames: Optional[int] = 81, - frame_rate: Optional[int] = 24, + height: int = 352, + width: int = 640, + num_frames: int = 81, + frame_rate: int = 24, # Classifier-free guidance - cfg_scale: Optional[float] = 5.0, + cfg_scale: float = 5.0, # Boundary - switch_DiT_boundary: Optional[float] = 0.9, + switch_DiT_boundary: float = 0.9, # Scheduler - num_inference_steps: Optional[int] = 50, - sigma_shift: Optional[float] = 5.0, + num_inference_steps: int = 50, + sigma_shift: float = 5.0, # VAE tiling - tiled: Optional[bool] = True, - tile_size: Optional[tuple[int, int]] = (30, 52), - tile_stride: Optional[tuple[int, int]] = (15, 26), + tiled: bool = True, + tile_size: tuple[int, int] = (30, 52), + tile_stride: tuple[int, int] = (15, 26), # progress_bar progress_bar_cmd=tqdm, ): diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index f3256a1..ad6bb31 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, + prompt: str = "", negative_prompt: str = "", cfg_scale: float = 4.0, # Image diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 1c1aa7e..a0a8f7e 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -190,82 +190,82 @@ class WanVideoPipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, - negative_prompt: Optional[str] = "", + prompt: str = "", + negative_prompt: str = "", # Image-to-video - input_image: Optional[Image.Image] = None, + input_image: Image.Image = None, # First-last-frame-to-video - end_image: Optional[Image.Image] = None, + end_image: Image.Image = None, # Video-to-video - input_video: Optional[list[Image.Image]] = None, - denoising_strength: Optional[float] = 1.0, + input_video: list[Image.Image] = None, + denoising_strength: float = 1.0, # Speech-to-video - input_audio: Optional[np.array] = None, - audio_embeds: Optional[torch.Tensor] = None, - audio_sample_rate: Optional[int] = 16000, - s2v_pose_video: Optional[list[Image.Image]] = None, - s2v_pose_latents: Optional[torch.Tensor] = None, - motion_video: Optional[list[Image.Image]] = None, + input_audio: np.array = None, + audio_embeds: torch.Tensor = None, + audio_sample_rate: int = 16000, + s2v_pose_video: list[Image.Image] = None, + s2v_pose_latents: torch.Tensor = None, + motion_video: list[Image.Image] = None, # ControlNet - control_video: Optional[list[Image.Image]] = None, - reference_image: Optional[Image.Image] = None, + control_video: list[Image.Image] = None, + reference_image: Image.Image = None, # Camera control - camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None, - camera_control_speed: Optional[float] = 1/54, - camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), + camera_control_direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"] = None, + camera_control_speed: float = 1/54, + camera_control_origin: tuple = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0), # VACE - vace_video: Optional[list[Image.Image]] = None, - vace_video_mask: Optional[Image.Image] = None, - vace_reference_image: Optional[Image.Image] = None, - vace_scale: Optional[float] = 1.0, + vace_video: list[Image.Image] = None, + vace_video_mask: Image.Image = None, + vace_reference_image: Image.Image = None, + vace_scale: float = 1.0, # Animate - animate_pose_video: Optional[list[Image.Image]] = None, - animate_face_video: Optional[list[Image.Image]] = None, - animate_inpaint_video: Optional[list[Image.Image]] = None, - animate_mask_video: Optional[list[Image.Image]] = None, + animate_pose_video: list[Image.Image] = None, + animate_face_video: list[Image.Image] = None, + animate_inpaint_video: list[Image.Image] = None, + animate_mask_video: list[Image.Image] = None, # VAP - vap_video: Optional[list[Image.Image]] = None, - vap_prompt: Optional[str] = " ", - negative_vap_prompt: Optional[str] = " ", + vap_video: list[Image.Image] = None, + vap_prompt: str = " ", + negative_vap_prompt: str = " ", # Randomness - seed: Optional[int] = None, - rand_device: Optional[str] = "cpu", + seed: int = None, + rand_device: str = "cpu", # Shape - height: Optional[int] = 480, - width: Optional[int] = 832, - num_frames=81, + height: int = 480, + width: int = 832, + num_frames: int = 81, # Classifier-free guidance - cfg_scale: Optional[float] = 5.0, - cfg_merge: Optional[bool] = False, + cfg_scale: float = 5.0, + cfg_merge: bool = False, # Boundary - switch_DiT_boundary: Optional[float] = 0.875, + switch_DiT_boundary: float = 0.875, # Scheduler - num_inference_steps: Optional[int] = 50, - sigma_shift: Optional[float] = 5.0, + num_inference_steps: int = 50, + sigma_shift: float = 5.0, # Speed control - motion_bucket_id: Optional[int] = None, + motion_bucket_id: int = None, # LongCat-Video - longcat_video: Optional[list[Image.Image]] = None, + longcat_video: list[Image.Image] = None, # VAE tiling - tiled: Optional[bool] = True, - tile_size: Optional[tuple[int, int]] = (30, 52), - tile_stride: Optional[tuple[int, int]] = (15, 26), + tiled: bool = True, + tile_size: tuple[int, int] = (30, 52), + tile_stride: tuple[int, int] = (15, 26), # Sliding window - sliding_window_size: Optional[int] = None, - sliding_window_stride: Optional[int] = None, + sliding_window_size: int = None, + sliding_window_stride: int = None, # Teacache - tea_cache_l1_thresh: Optional[float] = None, - tea_cache_model_id: Optional[str] = "", + tea_cache_l1_thresh: float = None, + tea_cache_model_id: str = "", # WanToDance - wantodance_music_path: Optional[str] = None, - wantodance_reference_image: Optional[Image.Image] = None, - wantodance_fps: Optional[float] = 30, - wantodance_keyframes: Optional[list[Image.Image]] = None, - wantodance_keyframes_mask: Optional[list[int]] = None, + wantodance_music_path: str = None, + wantodance_reference_image: Image.Image = None, + wantodance_fps: float = 30, + wantodance_keyframes: list[Image.Image] = None, + wantodance_keyframes_mask: list[int] = None, framewise_decoding: bool = False, # progress_bar progress_bar_cmd=tqdm, - output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized", + output_type: Literal["quantized", "floatpoint"] = "quantized", ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift) diff --git a/diffsynth/pipelines/z_image.py b/diffsynth/pipelines/z_image.py index 59e44b3..79d82d8 100644 --- a/diffsynth/pipelines/z_image.py +++ b/diffsynth/pipelines/z_image.py @@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline): def __call__( self, # Prompt - prompt: str, + prompt: str = "", negative_prompt: str = "", cfg_scale: float = 1.0, # Image @@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline): width: int = 1024, # Randomness seed: int = None, - rand_device: str = "cpu", + rand_device: Union[str, torch.device] = "cpu", # Steps num_inference_steps: int = 8, sigma_shift: float = None, diff --git a/examples/dev_tools/webui.py b/examples/dev_tools/webui.py new file mode 100644 index 0000000..a36a63f --- /dev/null +++ b/examples/dev_tools/webui.py @@ -0,0 +1,331 @@ +import importlib, inspect, pkgutil, traceback, torch, os, re, typing +from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal +from contextlib import contextmanager +from diffsynth.utils.data import VideoData +import streamlit as st +from diffsynth import ModelConfig +from diffsynth.diffusion.base_pipeline import ControlNetInput +from PIL import Image +from tqdm import tqdm +st.set_page_config(layout="wide") + +class StreamlitTqdmWrapper: + """Wrapper class that combines tqdm and streamlit progress bar""" + def __init__(self, iterable, st_progress_bar=None): + self.iterable = iterable + self.st_progress_bar = st_progress_bar + self.tqdm_bar = tqdm(iterable) + self.total = len(iterable) if hasattr(iterable, '__len__') else None + self.current = 0 + + def __iter__(self): + for item in self.tqdm_bar: + if self.st_progress_bar is not None and self.total is not None: + self.current += 1 + self.st_progress_bar.progress(self.current / self.total) + yield item + + def __enter__(self): + return self + + def __exit__(self, *args): + if hasattr(self.tqdm_bar, '__exit__'): + self.tqdm_bar.__exit__(*args) + +@contextmanager +def catch_error(error_value): + try: + yield + except Exception as e: + error_message = traceback.format_exc() + print(f"Error {error_value}:\n{error_message}") + +def parse_model_configs_from_an_example(path): + model_configs = [] + with open(path, "r") as f: + for code in f.readlines(): + code = code.strip() + if not code.startswith("ModelConfig"): + continue + pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code) + config_dict = {k: v for k, v in pairs} + model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"])) + return model_configs + +def list_examples(path, keyword=None): + examples = [] + if os.path.isdir(path): + for file_name in os.listdir(path): + examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword)) + elif path.endswith(".py"): + with open(path, "r") as f: + code = f.read() + if keyword is None or keyword in code: + examples.extend([path]) + return examples + +def parse_available_pipelines(): + from diffsynth.diffusion.base_pipeline import BasePipeline + import diffsynth.pipelines as _pipelines_pkg + available_pipelines = {} + for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__): + with catch_error(f"Failed: import diffsynth.pipelines.{name}"): + mod = importlib.import_module(f"diffsynth.pipelines.{name}") + classes = { + cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass) + if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__ + } + available_pipelines.update(classes) + return available_pipelines + +def parse_available_examples(path, available_pipelines): + available_examples = {} + for pipeline_name in available_pipelines: + examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained") + available_examples[pipeline_name] = examples + return available_examples + +def draw_selectbox(label, options, option_map, value=None, disabled=False): + default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0]) + option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled) + return option_map.get(option) + +def parse_params(fn): + params = [] + for name, param in inspect.signature(fn).parameters.items(): + annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None + default = param.default if param.default is not inspect.Parameter.empty else None + params.append({"name": name, "dtype": annotation, "value": default}) + return params + +def draw_model_config(model_config=None, key_suffix="", disabled=False): + with st.container(border=True): + if model_config is None: + model_config = ModelConfig() + path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled) + col1, col2 = st.columns(2) + with col1: + model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled) + with col2: + origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled) + model_config = ModelConfig( + path=None if path == "" else path, + model_id=model_id, + origin_file_pattern=origin_file_pattern, + ) + return model_config + +def draw_multi_model_config(name="", value=None, disabled=False): + model_configs = [] + with st.container(border=True): + st.markdown(name) + num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled) + for i in range(num): + model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled) + model_configs.append(model_config) + return model_configs + +def draw_single_model_config(name="", value=None, disabled=False): + with st.container(border=True): + st.markdown(name) + model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled) + return model_config + +def draw_multi_images(name="", value=None, disabled=False): + images = [] + with st.container(border=True): + st.markdown(name) + num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled) + for i in range(num): + image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled) + if image is not None: images.append(Image.open(image)) + return images + +def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None): + if kwargs is None: + kwargs = {} + elements = [] + with st.container(border=True): + st.markdown(name) + num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled) + for i in range(num): + element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs) + elements.append(element) + return elements + +def draw_controlnet_input(name="", value=None, disabled=False): + with st.container(border=True): + st.markdown(name) + controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id") + scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale") + image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image") + if image is not None: image = Image.open(image) + inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image") + if inpaint_image is not None: inpaint_image = Image.open(inpaint_image) + inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask") + if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask) + return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask) + +def draw_controlnet_inputs(name, value=None, disabled=False): + controlnet_inputs = [] + with st.container(border=True): + st.markdown(name) + num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled) + for i in range(num): + controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled) + controlnet_inputs.append(controlnet_input) + return controlnet_inputs + +def draw_ui_element(name, dtype, value): + unsupported_dtype = [ + Dict[str, torch.Tensor], + torch.Tensor, + ] + if dtype in unsupported_dtype: + return + if value is None: + with st.container(border=True): + enable = st.checkbox(f"Enable {name}", value=False) + ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable) + if enable: + return ui + else: + return None + else: + return draw_ui_element_safely(name, dtype, value) + +def draw_video(name, value=None, disabled=False): + ui = st.file_uploader(name, type=["mp4"], disabled=disabled) + if ui is not None: + ui = VideoData(ui) + ui = [ui[i] for i in range(len(ui))] + return ui + +def draw_ui_element_safely(name, dtype, value, disabled=False): + if dtype == torch.dtype: + option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} + ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled) + elif dtype == Union[str, torch.device]: + option_map = {"cuda": "cuda", "cpu": "cpu"} + ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled) + elif dtype == bool: + ui = st.checkbox(name, value=value, disabled=disabled) + elif dtype == ModelConfig: + ui = draw_single_model_config(name, value=value, disabled=disabled) + elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]: + if name == "model_configs" and "model_configs_from_example" in st.session_state: + model_configs = st.session_state["model_configs_from_example"] + del st.session_state["model_configs_from_example"] + ui = draw_multi_model_config(name, model_configs, disabled=disabled) + else: + ui = draw_multi_model_config(name, disabled=disabled) + elif dtype == str: + if "prompt" in name: + ui = st.text_area(name, value=value, height=3, disabled=disabled) + else: + ui = st.text_input(name, value=value, disabled=disabled) + elif dtype == float: + ui = st.number_input(name, value=value, disabled=disabled) + elif dtype == int: + ui = st.number_input(name, value=value, step=1, disabled=disabled) + elif dtype == Image.Image: + ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled) + if ui is not None: ui = Image.open(ui) + elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]: + if "video" in name: + ui = draw_video(name, value=value, disabled=disabled) + else: + ui = draw_multi_images(name, value=value, disabled=disabled) + elif dtype in [List[ControlNetInput], list[ControlNetInput]]: + ui = draw_controlnet_inputs(name, value=value, disabled=disabled) + elif dtype in [List[str], list[str]]: + ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled) + elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]: + ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled) + elif dtype in [List[int], list[int]]: + ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1}) + elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]: + ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled) + elif dtype in [tuple[int, int], Tuple[int, int]]: + with st.container(border=True): + st.markdown(name) + ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled)) + elif isinstance(dtype, typing._LiteralGenericAlias): + with st.container(border=True): + st.markdown(f"{name} ({dtype})") + ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden") + elif dtype is None: + if name == "progress_bar_cmd": + ui = value + else: + st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.") + ui = value + return ui + + +def launch_webui(): + input_col, output_col = st.columns(2) + with input_col: + if "available_pipelines" not in st.session_state: + st.session_state["available_pipelines"] = parse_available_pipelines() + if "available_examples" not in st.session_state: + st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"]) + + with st.expander("Pipeline", expanded=True): + pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"]) + example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__]) + + # Clear if pipeline is changed + if "prev_pipeline_class" in st.session_state and st.session_state["prev_pipeline_class"] != pipeline_class: + if "pipeline_class" in st.session_state: del st.session_state["pipeline_class"] + if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"] + if "prev_example" in st.session_state and st.session_state["prev_example"] != example: + if "model_configs_from_example" in st.session_state: del st.session_state["model_configs_from_example"] + st.session_state["prev_pipeline_class"] = pipeline_class + st.session_state["prev_example"] = example + + if st.button("Step 1: Parse Pipeline", type="primary"): + st.session_state["pipeline_class"] = pipeline_class + if example != "None": + st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example) + + if "pipeline_class" not in st.session_state: + return + with st.expander("Model", expanded=True): + input_params = {} + params = parse_params(pipeline_class.from_pretrained) + for param in params: + input_params[param["name"]] = draw_ui_element(**param) + if st.button("Step 2: Load Models", type="primary"): + with st.spinner("Loading models", show_time=True): + if "pipe" in st.session_state: + del st.session_state["pipe"] + torch.cuda.empty_cache() + st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params) + + if "pipe" not in st.session_state: + return + with st.expander("Input", expanded=True): + pipe = st.session_state["pipe"] + input_params = {} + params = parse_params(pipeline_class.__call__) + for param in params: + if param["name"] in ["self"]: + continue + input_params[param["name"]] = draw_ui_element(**param) + + with output_col: + if st.button("Step 3: Generate", type="primary"): + if "progress_bar_cmd" in input_params: + input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0)) + result = pipe(**input_params) + st.session_state["result"] = result + + if "result" in st.session_state: + result = st.session_state["result"] + if isinstance(result, Image.Image): + st.image(result) + else: + print(f"unsupported result format: {result}") + +launch_webui()