Compare commits

...

2 Commits

Author       SHA1        Message            Date
Artiprocher  e2a04139b4  update webui       2026-04-14 15:50:48 +08:00
lzws         db0f1571b1  fix pipeline args  2026-04-14 15:17:19 +08:00
8 changed files with 159 additions and 126 deletions

View File

@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 4.0,
         # Image

View File

@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 1.0,
         embedded_guidance: float = 4.0,
@@ -83,7 +83,7 @@ class Flux2ImagePipeline(BasePipeline):
         input_image: Image.Image = None,
         denoising_strength: float = 1.0,
         # Edit
-        edit_image: Union[Image.Image, List[Image.Image]] = None,
+        edit_image: List[Image.Image] = None,
         edit_image_auto_resize: bool = True,
         # Shape
         height: int = 1024,

View File

@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 1.0,
         embedded_guidance: float = 3.5,
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
         sigma_shift: float = None,
         # Steps
         num_inference_steps: int = 30,
-        # local prompts
-        multidiffusion_prompts=(),
-        multidiffusion_masks=(),
-        multidiffusion_scales=(),
         # Kontext
         kontext_images: Union[list[Image.Image], Image.Image] = None,
         # ControlNet
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
"height": height, "width": width, "height": height, "width": width,
"seed": seed, "rand_device": rand_device, "seed": seed, "rand_device": rand_device,
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps, "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
"kontext_images": kontext_images, "kontext_images": kontext_images,
"controlnet_inputs": controlnet_inputs, "controlnet_inputs": controlnet_inputs,
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale, "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,

View File

@@ -169,46 +169,46 @@ class LTX2AudioVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
-        negative_prompt: Optional[str] = "",
+        prompt: str = "",
+        negative_prompt: str = "",
         denoising_strength: float = 1.0,
         # Image-to-video
-        input_images: Optional[list[Image.Image]] = None,
-        input_images_indexes: Optional[list[int]] = [0],
-        input_images_strength: Optional[float] = 1.0,
+        input_images: list[Image.Image] = None,
+        input_images_indexes: list[int] = [0],
+        input_images_strength: float = 1.0,
         # In-Context Video Control
-        in_context_videos: Optional[list[list[Image.Image]]] = None,
-        in_context_downsample_factor: Optional[int] = 2,
+        in_context_videos: list[list[Image.Image]] = None,
+        in_context_downsample_factor: int = 2,
         # Video-to-video
-        retake_video: Optional[list[Image.Image]] = None,
-        retake_video_regions: Optional[list[tuple[float, float]]] = None,
+        retake_video: list[Image.Image] = None,
+        retake_video_regions: list[tuple[float, float]] = None,
         # Audio-to-video
-        retake_audio: Optional[torch.Tensor] = None,
-        audio_sample_rate: Optional[int] = 48000,
-        retake_audio_regions: Optional[list[tuple[float, float]]] = None,
+        retake_audio: torch.Tensor = None,
+        audio_sample_rate: int = 48000,
+        retake_audio_regions: list[tuple[float, float]] = None,
         # Randomness
-        seed: Optional[int] = None,
-        rand_device: Optional[str] = "cpu",
+        seed: int = None,
+        rand_device: str = "cpu",
         # Shape
-        height: Optional[int] = 512,
-        width: Optional[int] = 768,
-        num_frames: Optional[int] = 121,
-        frame_rate: Optional[int] = 24,
+        height: int = 512,
+        width: int = 768,
+        num_frames: int = 121,
+        frame_rate: int = 24,
         # Classifier-free guidance
-        cfg_scale: Optional[float] = 3.0,
+        cfg_scale: float = 3.0,
         # Scheduler
-        num_inference_steps: Optional[int] = 30,
+        num_inference_steps: int = 30,
         # VAE tiling
-        tiled: Optional[bool] = True,
-        tile_size_in_pixels: Optional[int] = 512,
-        tile_overlap_in_pixels: Optional[int] = 128,
-        tile_size_in_frames: Optional[int] = 128,
-        tile_overlap_in_frames: Optional[int] = 24,
+        tiled: bool = True,
+        tile_size_in_pixels: int = 512,
+        tile_overlap_in_pixels: int = 128,
+        tile_size_in_frames: int = 128,
+        tile_overlap_in_frames: int = 24,
         # Special Pipelines
-        use_two_stage_pipeline: Optional[bool] = False,
-        stage2_spatial_upsample_factor: Optional[int] = 2,
-        clear_lora_before_state_two: Optional[bool] = False,
-        use_distilled_pipeline: Optional[bool] = False,
+        use_two_stage_pipeline: bool = False,
+        stage2_spatial_upsample_factor: int = 2,
+        clear_lora_before_state_two: bool = False,
+        use_distilled_pipeline: bool = False,
         # progress_bar
         progress_bar_cmd=tqdm,
     ):
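
Reviewer note: across LTX2, Mova, and Wan, the Optional[...] wrappers are stripped from the annotations. A likely reason, given the webui code at the end of this diff, is that the widget dispatch compares annotations by equality (dtype == str, dtype == int, ...), and Optional[str] is Union[str, None], which never compares equal to str. A quick standalone demonstration of the mismatch (plain Python, no repo code involved):

from typing import Optional, Union

# Optional[X] is just Union[X, None]; equality with the bare type fails,
# so an `elif dtype == str:` branch would silently skip Optional[str].
assert Optional[str] == Union[str, None]
assert Optional[str] != str
assert Optional[int] != int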

View File

@@ -115,33 +115,33 @@ class MovaAudioVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
-        negative_prompt: Optional[str] = "",
+        prompt: str = "",
+        negative_prompt: str = "",
         # Image-to-video
-        input_image: Optional[Image.Image] = None,
+        input_image: Image.Image = None,
         # First-last-frame-to-video
-        end_image: Optional[Image.Image] = None,
+        end_image: Image.Image = None,
         # Video-to-video
-        denoising_strength: Optional[float] = 1.0,
+        denoising_strength: float = 1.0,
         # Randomness
-        seed: Optional[int] = None,
-        rand_device: Optional[str] = "cpu",
+        seed: int = None,
+        rand_device: str = "cpu",
         # Shape
-        height: Optional[int] = 352,
-        width: Optional[int] = 640,
-        num_frames: Optional[int] = 81,
-        frame_rate: Optional[int] = 24,
+        height: int = 352,
+        width: int = 640,
+        num_frames: int = 81,
+        frame_rate: int = 24,
         # Classifier-free guidance
-        cfg_scale: Optional[float] = 5.0,
+        cfg_scale: float = 5.0,
         # Boundary
-        switch_DiT_boundary: Optional[float] = 0.9,
+        switch_DiT_boundary: float = 0.9,
         # Scheduler
-        num_inference_steps: Optional[int] = 50,
-        sigma_shift: Optional[float] = 5.0,
+        num_inference_steps: int = 50,
+        sigma_shift: float = 5.0,
         # VAE tiling
-        tiled: Optional[bool] = True,
-        tile_size: Optional[tuple[int, int]] = (30, 52),
-        tile_stride: Optional[tuple[int, int]] = (15, 26),
+        tiled: bool = True,
+        tile_size: tuple[int, int] = (30, 52),
+        tile_stride: tuple[int, int] = (15, 26),
         # progress_bar
         progress_bar_cmd=tqdm,
     ):

View File

@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 4.0,
         # Image

View File

@@ -190,82 +190,82 @@ class WanVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
-        negative_prompt: Optional[str] = "",
+        prompt: str = "",
+        negative_prompt: str = "",
         # Image-to-video
-        input_image: Optional[Image.Image] = None,
+        input_image: Image.Image = None,
         # First-last-frame-to-video
-        end_image: Optional[Image.Image] = None,
+        end_image: Image.Image = None,
         # Video-to-video
-        input_video: Optional[list[Image.Image]] = None,
-        denoising_strength: Optional[float] = 1.0,
+        input_video: list[Image.Image] = None,
+        denoising_strength: float = 1.0,
         # Speech-to-video
-        input_audio: Optional[np.array] = None,
-        audio_embeds: Optional[torch.Tensor] = None,
-        audio_sample_rate: Optional[int] = 16000,
-        s2v_pose_video: Optional[list[Image.Image]] = None,
-        s2v_pose_latents: Optional[torch.Tensor] = None,
-        motion_video: Optional[list[Image.Image]] = None,
+        input_audio: np.array = None,
+        audio_embeds: torch.Tensor = None,
+        audio_sample_rate: int = 16000,
+        s2v_pose_video: list[Image.Image] = None,
+        s2v_pose_latents: torch.Tensor = None,
+        motion_video: list[Image.Image] = None,
         # ControlNet
-        control_video: Optional[list[Image.Image]] = None,
-        reference_image: Optional[Image.Image] = None,
+        control_video: list[Image.Image] = None,
+        reference_image: Image.Image = None,
         # Camera control
-        camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
-        camera_control_speed: Optional[float] = 1/54,
-        camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
+        camera_control_direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"] = None,
+        camera_control_speed: float = 1/54,
+        camera_control_origin: tuple = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
         # VACE
-        vace_video: Optional[list[Image.Image]] = None,
-        vace_video_mask: Optional[Image.Image] = None,
-        vace_reference_image: Optional[Image.Image] = None,
-        vace_scale: Optional[float] = 1.0,
+        vace_video: list[Image.Image] = None,
+        vace_video_mask: Image.Image = None,
+        vace_reference_image: Image.Image = None,
+        vace_scale: float = 1.0,
         # Animate
-        animate_pose_video: Optional[list[Image.Image]] = None,
-        animate_face_video: Optional[list[Image.Image]] = None,
-        animate_inpaint_video: Optional[list[Image.Image]] = None,
-        animate_mask_video: Optional[list[Image.Image]] = None,
+        animate_pose_video: list[Image.Image] = None,
+        animate_face_video: list[Image.Image] = None,
+        animate_inpaint_video: list[Image.Image] = None,
+        animate_mask_video: list[Image.Image] = None,
         # VAP
-        vap_video: Optional[list[Image.Image]] = None,
-        vap_prompt: Optional[str] = " ",
-        negative_vap_prompt: Optional[str] = " ",
+        vap_video: list[Image.Image] = None,
+        vap_prompt: str = " ",
+        negative_vap_prompt: str = " ",
         # Randomness
-        seed: Optional[int] = None,
-        rand_device: Optional[str] = "cpu",
+        seed: int = None,
+        rand_device: str = "cpu",
         # Shape
-        height: Optional[int] = 480,
-        width: Optional[int] = 832,
-        num_frames=81,
+        height: int = 480,
+        width: int = 832,
+        num_frames: int = 81,
         # Classifier-free guidance
-        cfg_scale: Optional[float] = 5.0,
-        cfg_merge: Optional[bool] = False,
+        cfg_scale: float = 5.0,
+        cfg_merge: bool = False,
         # Boundary
-        switch_DiT_boundary: Optional[float] = 0.875,
+        switch_DiT_boundary: float = 0.875,
         # Scheduler
-        num_inference_steps: Optional[int] = 50,
-        sigma_shift: Optional[float] = 5.0,
+        num_inference_steps: int = 50,
+        sigma_shift: float = 5.0,
         # Speed control
-        motion_bucket_id: Optional[int] = None,
+        motion_bucket_id: int = None,
         # LongCat-Video
-        longcat_video: Optional[list[Image.Image]] = None,
+        longcat_video: list[Image.Image] = None,
         # VAE tiling
-        tiled: Optional[bool] = True,
-        tile_size: Optional[tuple[int, int]] = (30, 52),
-        tile_stride: Optional[tuple[int, int]] = (15, 26),
+        tiled: bool = True,
+        tile_size: tuple[int, int] = (30, 52),
+        tile_stride: tuple[int, int] = (15, 26),
         # Sliding window
-        sliding_window_size: Optional[int] = None,
-        sliding_window_stride: Optional[int] = None,
+        sliding_window_size: int = None,
+        sliding_window_stride: int = None,
         # Teacache
-        tea_cache_l1_thresh: Optional[float] = None,
-        tea_cache_model_id: Optional[str] = "",
+        tea_cache_l1_thresh: float = None,
+        tea_cache_model_id: str = "",
         # WanToDance
-        wantodance_music_path: Optional[str] = None,
-        wantodance_reference_image: Optional[Image.Image] = None,
-        wantodance_fps: Optional[float] = 30,
-        wantodance_keyframes: Optional[list[Image.Image]] = None,
-        wantodance_keyframes_mask: Optional[list[int]] = None,
+        wantodance_music_path: str = None,
+        wantodance_reference_image: Image.Image = None,
+        wantodance_fps: float = 30,
+        wantodance_keyframes: list[Image.Image] = None,
+        wantodance_keyframes_mask: list[int] = None,
         framewise_decoding: bool = False,
         # progress_bar
         progress_bar_cmd=tqdm,
-        output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
+        output_type: Literal["quantized", "floatpoint"] = "quantized",
     ):
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

View File

@@ -1,6 +1,7 @@
-import importlib, inspect, pkgutil, traceback, torch, os, re
-from typing import Union, List, Optional, Tuple, Iterable, Dict
+import importlib, inspect, pkgutil, traceback, torch, os, re, typing
+from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
 from contextlib import contextmanager
+from diffsynth.utils.data import VideoData
 import streamlit as st
 from diffsynth import ModelConfig
 from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -140,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
         if image is not None: images.append(Image.open(image))
     return images

+def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
+    if kwargs is None:
+        kwargs = {}
+    elements = []
+    with st.container(border=True):
+        st.markdown(name)
+        num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
+        for i in range(num):
+            element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
+            elements.append(element)
+    return elements
+
 def draw_controlnet_input(name="", value=None, disabled=False):
     with st.container(border=True):
         st.markdown(name)
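
Reviewer note: the new draw_multi_elements helper generalizes draw_multi_images: it renders a count selector plus num copies of an arbitrary Streamlit widget and returns the collected values, which is what lets the webui edit list-typed parameters below. A hypothetical call for a list[int] parameter, mirroring the dispatch added later in this diff (argument values are illustrative):

indexes = draw_multi_elements(st.number_input, name="input_images_indexes", value=[0], kwargs={"step": 1})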
@@ -173,7 +186,7 @@ def draw_ui_element(name, dtype, value):
     if value is None:
         with st.container(border=True):
             enable = st.checkbox(f"Enable {name}", value=False)
-            ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
+            ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
             if enable:
                 return ui
             else:
@@ -181,6 +194,13 @@ def draw_ui_element(name, dtype, value):
     else:
         return draw_ui_element_safely(name, dtype, value)

+def draw_video(name, value=None, disabled=False):
+    ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
+    if ui is not None:
+        ui = VideoData(ui)
+        ui = [ui[i] for i in range(len(ui))]
+    return ui
+
 def draw_ui_element_safely(name, dtype, value, disabled=False):
     if dtype == torch.dtype:
         option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
@@ -189,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
option_map = {"cuda": "cuda", "cpu": "cpu"} option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled) ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool: elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled) ui = st.checkbox(name, value=value, disabled=disabled)
elif dtype == ModelConfig: elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled) ui = draw_single_model_config(name, value=value, disabled=disabled)
elif dtype == list[ModelConfig]: elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
if name == "model_configs" and "model_configs_from_example" in st.session_state: if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"] model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"] del st.session_state["model_configs_from_example"]
@@ -201,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
             ui = draw_multi_model_config(name, disabled=disabled)
     elif dtype == str:
         if "prompt" in name:
-            ui = st.text_area(name, value, height=3, disabled=disabled)
+            ui = st.text_area(name, value=value, height=3, disabled=disabled)
         else:
-            ui = st.text_input(name, value, disabled=disabled)
+            ui = st.text_input(name, value=value, disabled=disabled)
     elif dtype == float:
-        ui = st.number_input(name, value, disabled=disabled)
+        ui = st.number_input(name, value=value, disabled=disabled)
     elif dtype == int:
-        ui = st.number_input(name, value, step=1, disabled=disabled)
+        ui = st.number_input(name, value=value, step=1, disabled=disabled)
     elif dtype == Image.Image:
         ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
         if ui is not None: ui = Image.open(ui)
-    elif dtype == List[Image.Image]:
-        ui = draw_multi_images(name, value, disabled=disabled)
-    elif dtype == List[ControlNetInput]:
-        ui = draw_controlnet_inputs(name, value, disabled=disabled)
+    elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
+        if "video" in name:
+            ui = draw_video(name, value=value, disabled=disabled)
+        else:
+            ui = draw_multi_images(name, value=value, disabled=disabled)
+    elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
+        ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
+    elif dtype in [List[str], list[str]]:
+        ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
+    elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
+        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
+    elif dtype in [List[int], list[int]]:
+        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
+    elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
+        ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
+    elif dtype in [tuple[int, int], Tuple[int, int]]:
+        with st.container(border=True):
+            st.markdown(name)
+            ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
+    elif isinstance(dtype, typing._LiteralGenericAlias):
+        with st.container(border=True):
+            st.markdown(f"{name} ({dtype})")
+            ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
     elif dtype is None:
         if name == "progress_bar_cmd":
             ui = value
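
Reviewer note: one caveat on the Literal branch above: typing._LiteralGenericAlias is a private CPython detail and has varied across Python versions; the public way to detect a Literal annotation is typing.get_origin. A small sketch of that alternative, which is not what this diff does:

import typing
from typing import Literal

dtype = Literal["quantized", "floatpoint"]
# get_origin/get_args are the stable introspection API for Literal types.
assert typing.get_origin(dtype) is Literal
assert typing.get_args(dtype) == ("quantized", "floatpoint")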
@@ -259,7 +298,7 @@ def launch_webui():
     with st.expander("Input", expanded=True):
         pipe = st.session_state["pipe"]
         input_params = {}
-        params = parse_params(pipe.__call__)
+        params = parse_params(pipeline_class.__call__)
         for param in params:
             if param["name"] in ["self"]:
                 continue
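
Reviewer note on the switch from pipe.__call__ to pipeline_class.__call__: inspect.signature on a bound method omits self, while the class-level function includes it, which is why the `if param["name"] in ["self"]` guard stays necessary. A standalone demonstration of the difference (toy class, not repo code):

import inspect

class Pipe:
    def __call__(self, prompt: str = ""): ...

print(list(inspect.signature(Pipe().__call__).parameters))  # ['prompt'] -- bound method hides self
print(list(inspect.signature(Pipe.__call__).parameters))    # ['self', 'prompt']

A plausible motivation is that parsing from the class keeps the original signature even if the stored instance wraps or replaces __call__; the diff itself does not say.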
@@ -280,4 +319,3 @@ def launch_webui():
print(f"unsupported result format: {result}") print(f"unsupported result format: {result}")
launch_webui() launch_webui()
# streamlit run examples/dev_tools/webui.py --server.fileWatcherType none