mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-14 21:58:17 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2a04139b4 | ||
|
|
db0f1571b1 |
@@ -74,7 +74,7 @@ class AnimaImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
# Image
|
# Image
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class Flux2ImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
embedded_guidance: float = 4.0,
|
embedded_guidance: float = 4.0,
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 1.0,
|
cfg_scale: float = 1.0,
|
||||||
embedded_guidance: float = 3.5,
|
embedded_guidance: float = 3.5,
|
||||||
@@ -199,10 +199,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
sigma_shift: float = None,
|
sigma_shift: float = None,
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
# local prompts
|
|
||||||
multidiffusion_prompts:tuple[str] =(),
|
|
||||||
multidiffusion_masks:tuple[str]=(),
|
|
||||||
multidiffusion_scales:tuple[str]=(),
|
|
||||||
# Kontext
|
# Kontext
|
||||||
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
||||||
# ControlNet
|
# ControlNet
|
||||||
@@ -257,7 +253,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
"height": height, "width": width,
|
"height": height, "width": width,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
||||||
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
|
|
||||||
"kontext_images": kontext_images,
|
"kontext_images": kontext_images,
|
||||||
"controlnet_inputs": controlnet_inputs,
|
"controlnet_inputs": controlnet_inputs,
|
||||||
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
denoising_strength: float = 1.0,
|
denoising_strength: float = 1.0,
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class MovaAudioVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
input_image: Image.Image = None,
|
input_image: Image.Image = None,
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
cfg_scale: float = 4.0,
|
cfg_scale: float = 4.0,
|
||||||
# Image
|
# Image
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
# Prompt
|
# Prompt
|
||||||
prompt: str,
|
prompt: str = "",
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
# Image-to-video
|
# Image-to-video
|
||||||
input_image: Image.Image = None,
|
input_image: Image.Image = None,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import importlib, inspect, pkgutil, traceback, torch, os, re
|
import importlib, inspect, pkgutil, traceback, torch, os, re, typing
|
||||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
from typing import Union, List, Optional, Tuple, Iterable, Dict, Literal
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from diffsynth.utils.data import VideoData
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from diffsynth import ModelConfig
|
from diffsynth import ModelConfig
|
||||||
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
from diffsynth.diffusion.base_pipeline import ControlNetInput
|
||||||
@@ -141,6 +141,18 @@ def draw_multi_images(name="", value=None, disabled=False):
|
|||||||
if image is not None: images.append(Image.open(image))
|
if image is not None: images.append(Image.open(image))
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
def draw_multi_elements(st_element, name="", value=None, disabled=False, kwargs=None):
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
elements = []
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
|
||||||
|
for i in range(num):
|
||||||
|
element = st_element(name, key=f"{name}_{i}", disabled=disabled, value=None if value is None else value[i], **kwargs)
|
||||||
|
elements.append(element)
|
||||||
|
return elements
|
||||||
|
|
||||||
def draw_controlnet_input(name="", value=None, disabled=False):
|
def draw_controlnet_input(name="", value=None, disabled=False):
|
||||||
with st.container(border=True):
|
with st.container(border=True):
|
||||||
st.markdown(name)
|
st.markdown(name)
|
||||||
@@ -174,7 +186,7 @@ def draw_ui_element(name, dtype, value):
|
|||||||
if value is None:
|
if value is None:
|
||||||
with st.container(border=True):
|
with st.container(border=True):
|
||||||
enable = st.checkbox(f"Enable {name}", value=False)
|
enable = st.checkbox(f"Enable {name}", value=False)
|
||||||
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
|
ui = draw_ui_element_safely(name, dtype, value=value, disabled=not enable)
|
||||||
if enable:
|
if enable:
|
||||||
return ui
|
return ui
|
||||||
else:
|
else:
|
||||||
@@ -182,6 +194,13 @@ def draw_ui_element(name, dtype, value):
|
|||||||
else:
|
else:
|
||||||
return draw_ui_element_safely(name, dtype, value)
|
return draw_ui_element_safely(name, dtype, value)
|
||||||
|
|
||||||
|
def draw_video(name, value=None, disabled=False):
|
||||||
|
ui = st.file_uploader(name, type=["mp4"], disabled=disabled)
|
||||||
|
if ui is not None:
|
||||||
|
ui = VideoData(ui)
|
||||||
|
ui = [ui[i] for i in range(len(ui))]
|
||||||
|
return ui
|
||||||
|
|
||||||
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
def draw_ui_element_safely(name, dtype, value, disabled=False):
|
||||||
if dtype == torch.dtype:
|
if dtype == torch.dtype:
|
||||||
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||||
@@ -190,10 +209,10 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
|||||||
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
option_map = {"cuda": "cuda", "cpu": "cpu"}
|
||||||
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
|
||||||
elif dtype == bool:
|
elif dtype == bool:
|
||||||
ui = st.checkbox(name, value, disabled=disabled)
|
ui = st.checkbox(name, value=value, disabled=disabled)
|
||||||
elif dtype == ModelConfig:
|
elif dtype == ModelConfig:
|
||||||
ui = draw_single_model_config(name, value, disabled=disabled)
|
ui = draw_single_model_config(name, value=value, disabled=disabled)
|
||||||
elif dtype == list[ModelConfig]:
|
elif dtype in [list[ModelConfig], List[ModelConfig], Union[list[ModelConfig], ModelConfig, str]]:
|
||||||
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
if name == "model_configs" and "model_configs_from_example" in st.session_state:
|
||||||
model_configs = st.session_state["model_configs_from_example"]
|
model_configs = st.session_state["model_configs_from_example"]
|
||||||
del st.session_state["model_configs_from_example"]
|
del st.session_state["model_configs_from_example"]
|
||||||
@@ -202,20 +221,39 @@ def draw_ui_element_safely(name, dtype, value, disabled=False):
|
|||||||
ui = draw_multi_model_config(name, disabled=disabled)
|
ui = draw_multi_model_config(name, disabled=disabled)
|
||||||
elif dtype == str:
|
elif dtype == str:
|
||||||
if "prompt" in name:
|
if "prompt" in name:
|
||||||
ui = st.text_area(name, value, height=3, disabled=disabled)
|
ui = st.text_area(name, value=value, height=3, disabled=disabled)
|
||||||
else:
|
else:
|
||||||
ui = st.text_input(name, value, disabled=disabled)
|
ui = st.text_input(name, value=value, disabled=disabled)
|
||||||
elif dtype == float:
|
elif dtype == float:
|
||||||
ui = st.number_input(name, value, disabled=disabled)
|
ui = st.number_input(name, value=value, disabled=disabled)
|
||||||
elif dtype == int:
|
elif dtype == int:
|
||||||
ui = st.number_input(name, value, step=1, disabled=disabled)
|
ui = st.number_input(name, value=value, step=1, disabled=disabled)
|
||||||
elif dtype == Image.Image:
|
elif dtype == Image.Image:
|
||||||
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
|
||||||
if ui is not None: ui = Image.open(ui)
|
if ui is not None: ui = Image.open(ui)
|
||||||
elif dtype == List[Image.Image]:
|
elif dtype in [List[Image.Image], list[Image.Image], Union[list[Image.Image], Image.Image], Union[List[Image.Image], Image.Image]]:
|
||||||
ui = draw_multi_images(name, value, disabled=disabled)
|
if "video" in name:
|
||||||
elif dtype == List[ControlNetInput]:
|
ui = draw_video(name, value=value, disabled=disabled)
|
||||||
ui = draw_controlnet_inputs(name, value, disabled=disabled)
|
else:
|
||||||
|
ui = draw_multi_images(name, value=value, disabled=disabled)
|
||||||
|
elif dtype in [List[ControlNetInput], list[ControlNetInput]]:
|
||||||
|
ui = draw_controlnet_inputs(name, value=value, disabled=disabled)
|
||||||
|
elif dtype in [List[str], list[str]]:
|
||||||
|
ui = draw_multi_elements(st.text_input, name, value=value, disabled=disabled)
|
||||||
|
elif dtype in [List[float], list[float], Union[list[float], float], Union[List[float], float]]:
|
||||||
|
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled)
|
||||||
|
elif dtype in [List[int], list[int]]:
|
||||||
|
ui = draw_multi_elements(st.number_input, name, value=value, disabled=disabled, kwargs={"step": 1})
|
||||||
|
elif dtype in [List[List[Image.Image]], list[list[Image.Image]]]:
|
||||||
|
ui = draw_multi_elements(draw_video, name, value=value, disabled=disabled)
|
||||||
|
elif dtype in [tuple[int, int], Tuple[int, int]]:
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(name)
|
||||||
|
ui = (st.text_input(f"{name}_0", value=value[0], disabled=disabled), st.text_input(f"{name}_1", value=value[1], disabled=disabled))
|
||||||
|
elif isinstance(dtype, typing._LiteralGenericAlias):
|
||||||
|
with st.container(border=True):
|
||||||
|
st.markdown(f"{name} ({dtype})")
|
||||||
|
ui = st.text_input(name, value=value, disabled=disabled, label_visibility="hidden")
|
||||||
elif dtype is None:
|
elif dtype is None:
|
||||||
if name == "progress_bar_cmd":
|
if name == "progress_bar_cmd":
|
||||||
ui = value
|
ui = value
|
||||||
@@ -260,7 +298,7 @@ def launch_webui():
|
|||||||
with st.expander("Input", expanded=True):
|
with st.expander("Input", expanded=True):
|
||||||
pipe = st.session_state["pipe"]
|
pipe = st.session_state["pipe"]
|
||||||
input_params = {}
|
input_params = {}
|
||||||
params = parse_params(pipe.__call__)
|
params = parse_params(pipeline_class.__call__)
|
||||||
for param in params:
|
for param in params:
|
||||||
if param["name"] in ["self"]:
|
if param["name"] in ["self"]:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user