diff --git a/diffsynth/pipelines/anima_image.py b/diffsynth/pipelines/anima_image.py
index bc4f6cd..e29b6f9 100644
--- a/diffsynth/pipelines/anima_image.py
+++ b/diffsynth/pipelines/anima_image.py
@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 4.0,
         # Image
diff --git a/diffsynth/pipelines/flux2_image.py b/diffsynth/pipelines/flux2_image.py
index 4e837e9..fae29b6 100644
--- a/diffsynth/pipelines/flux2_image.py
+++ b/diffsynth/pipelines/flux2_image.py
@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 1.0,
         embedded_guidance: float = 4.0,
diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py
index fdace11..705012a 100644
--- a/diffsynth/pipelines/flux_image.py
+++ b/diffsynth/pipelines/flux_image.py
@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 1.0,
         embedded_guidance: float = 3.5,
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
         sigma_shift: float = None,
         # Steps
         num_inference_steps: int = 30,
-        # local prompts
-        multidiffusion_prompts:tuple[str] =(),
-        multidiffusion_masks:tuple[str]=(),
-        multidiffusion_scales:tuple[str]=(),
         # Kontext
         kontext_images: Union[list[Image.Image], Image.Image] = None,
         # ControlNet
@@ -257,7 +253,6 @@
             "height": height, "width": width,
             "seed": seed, "rand_device": rand_device,
             "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
-            "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
             "kontext_images": kontext_images,
             "controlnet_inputs": controlnet_inputs,
             "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py
index af31e08..10723e3 100644
--- a/diffsynth/pipelines/ltx2_audio_video.py
+++ b/diffsynth/pipelines/ltx2_audio_video.py
@@ -169,7 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         denoising_strength: float = 1.0,
         # Image-to-video
diff --git a/diffsynth/pipelines/mova_audio_video.py b/diffsynth/pipelines/mova_audio_video.py
index 9475e16..aa83440 100644
--- a/diffsynth/pipelines/mova_audio_video.py
+++ b/diffsynth/pipelines/mova_audio_video.py
@@ -115,7 +115,7 @@ class MovaAudioVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         # Image-to-video
         input_image: Image.Image = None,
diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index f3256a1..ad6bb31 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         cfg_scale: float = 4.0,
         # Image
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index 9022b13..a0a8f7e 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -190,7 +190,7 @@ class WanVideoPipeline(BasePipeline):
     def __call__(
         self,
         # Prompt
-        prompt: str,
+        prompt: str = "",
         negative_prompt: str = "",
         # Image-to-video
         input_image: Image.Image = None,
diff --git a/examples/dev_tools/webui.py b/examples/dev_tools/webui.py
index c133b06..8ce2be5 100644
--- a/examples/dev_tools/webui.py
+++ b/examples/dev_tools/webui.py
@@ -1,7 +1,7 @@
-import importlib, inspect, pkgutil, traceback, torch, os, re
-from typing import Union, List, Optional, Tuple, Iterable, Dict
+import importlib, inspect, pkgutil, traceback, torch, os, re, typing
+from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
 from contextlib import contextmanager
-
+from diffsynth.utils.data import VideoData
 import streamlit as st
 from diffsynth import ModelConfig
 from diffsynth.diffusion.base_pipeline import ControlNetInput
@@ -141,6 +141,18 @@
         if image is not None:
             images.append(Image.open(image))
     return images
+def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
+    if kwargs is None:
+        kwargs = {}
+    elements = []
+    with st.container(border=True):
+        st.markdown(name)
+        num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
+        for i in range(num):
+            element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
+            elements.append(element)
+    return elements
+
 def draw_controlnet_input(name="", value=None, disabled=False):
     with st.container(border=True):
         st.markdown(name)
@@ -174,7 +186,7 @@
     if value is None:
         with st.container(border=True):
             enable = st.checkbox(f"Enable {name}", value=False)
-            ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
+            ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
         if enable:
             return ui
         else:
@@ -182,6 +194,13 @@
     else:
         return draw_ui_element_safely(name, dtype, value)
 
+def draw_video(name, value=None, disabled=False):
+    ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
+    if ui is not None:
+        ui = VideoData(ui)
+        ui = [ui[i] for i in range(len(ui))]
+    return ui
+
 def draw_ui_element_safely(name, dtype, value, disabled=False):
     if dtype == torch.dtype:
         option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
@@ -190,10 +209,10 @@
         option_map = {"cuda": "cuda", "cpu": "cpu"}
         ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
     elif dtype == bool:
-        ui = st.checkbox(name, value, disabled=disabled)
+        ui = st.checkbox(name, value=value, disabled=disabled)
     elif dtype == ModelConfig:
-        ui = draw_single_model_config(name, value, disabled=disabled)
-    elif dtype == list[ModelConfig]:
+        ui = draw_single_model_config(name, value=value, disabled=disabled)
+    elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
         if name == "model_configs" and "model_configs_from_example" in st.session_state:
             model_configs = st.session_state["model_configs_from_example"]
             del st.session_state["model_configs_from_example"]
@@ -202,20 +221,39 @@
         else: ui = draw_multi_model_config(name, disabled=disabled)
     elif dtype == str:
         if "prompt" in name:
-            ui = st.text_area(name, value, height=3, disabled=disabled)
+            ui = st.text_area(name, value=value, height=3, disabled=disabled)
         else:
-            ui = st.text_input(name, value, disabled=disabled)
+            ui = st.text_input(name, value=value, disabled=disabled)
     elif dtype == float:
-        ui = st.number_input(name, value, disabled=disabled)
+        ui = st.number_input(name, value=value, disabled=disabled)
     elif dtype == int:
-        ui = st.number_input(name, value, step=1, disabled=disabled)
+        ui = st.number_input(name, value=value, step=1, disabled=disabled)
     elif dtype == Image.Image:
         ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
         if ui is not None: ui = Image.open(ui)
-    elif dtype == List[Image.Image]:
-        ui = draw_multi_images(name, value, disabled=disabled)
-    elif dtype == List[ControlNetInput]:
-        ui = draw_controlnet_inputs(name, value, disabled=disabled)
+    elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
+        if "video" in name:
+            ui = draw_video(name, value=value, disabled=disabled)
+        else:
+            ui = draw_multi_images(name, value=value, disabled=disabled)
+    elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
+        ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
+    elif dtype in [List[str], list[str]]:
+        ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
+    elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
+        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
+    elif dtype in [List[int], list[int]]:
+        ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
+    elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
+        ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
+    elif dtype in [tuple[int, int], Tuple[int, int]]:
+        with st.container(border=True):
+            st.markdown(name)
+            ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
+    elif isinstance(dtype, typing._LiteralGenericAlias):
+        with st.container(border=True):
+            st.markdown(f"{name} ({dtype})")
+            ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
     elif dtype is None:
         if name == "progress_bar_cmd":
             ui = value
@@ -260,7 +298,7 @@
     with st.expander("Input", expanded=True):
        pipe = st.session_state["pipe"]
        input_params = {}
-        params = parse_params(pipe.__call__)
+        params = parse_params(pipeline_class.__call__)
        for param in params:
            if param["name"] in ["self"]:
                continue
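
Reviewer note: the `prompt: str = ""` change is what lets the WebUI render every `__call__` parameter as an optional form field. A minimal usage sketch of the effect (not part of the patch; the `ModelConfig` contents are placeholders):

```python
# Sketch only: model IDs are placeholders, not part of this patch.
from diffsynth import ModelConfig
from diffsynth.pipelines.flux_image import FluxImagePipeline

pipe = FluxImagePipeline.from_pretrained(
    model_configs=[ModelConfig(model_id="...")],  # placeholder config
)

image = pipe(prompt="a cat")  # unchanged explicit usage
image = pipe()                # now valid: prompt defaults to ""
```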
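The widened `dtype in [...]` branches in `draw_ui_element_safely` work because parameterized generic aliases compare by value, while the builtin (`list[int]`) and `typing` (`List[int]`) spellings remain distinct objects, which is why each branch lists both. A minimal sketch, assuming Python 3.9+ (`typing._LiteralGenericAlias` is a private name, as the patch itself relies on):

```python
import typing
from typing import List, Literal, get_origin

# Parameterized generics compare by value, so `==`/`in` dispatch works:
assert list[int] == list[int]
assert List[int] in [List[int], list[int]]

# ...but the builtin and typing spellings do not compare equal:
assert List[int] != list[int]

# Literal[...] produces an instance of the private _LiteralGenericAlias
# class, matching the isinstance() fallback in the patch:
assert isinstance(Literal["a", "b"], typing._LiteralGenericAlias)
assert get_origin(Literal["a", "b"]) is Literal  # public alternative
```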
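The last hunk swaps `parse_params(pipe.__call__)` for `parse_params(pipeline_class.__call__)`. Assuming `parse_params` inspects the callable's signature (webui.py already imports `inspect`), the practical difference is whether `self` appears in the parameter list, which explains the `param["name"] in ["self"]` skip in the loop below the call. A sketch with a hypothetical `DemoPipeline`:

```python
import inspect

class DemoPipeline:
    def __call__(self, prompt: str = "", steps: int = 30): ...

pipe = DemoPipeline()
print(list(inspect.signature(pipe.__call__).parameters))
# ['prompt', 'steps']        -- bound method, `self` already consumed
print(list(inspect.signature(DemoPipeline.__call__).parameters))
# ['self', 'prompt', 'steps'] -- unbound function, hence the "self" skip
```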