update UI

This commit is contained in:
Artiprocher
2024-08-21 16:57:56 +08:00
parent a6aaf9da2a
commit d6d14859e3
5 changed files with 247 additions and 1 deletions

View File

@@ -0,0 +1,235 @@
import gradio as gr
from diffsynth import ModelManager, SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
import os, torch
from PIL import Image
import numpy as np
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
}
MAX_NUM_PAINTER_LAYERS = 8
def load_model_list(model_type):
if model_type is None:
return []
folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def load_model(model_type, model_path):
model_path = os.path.join(config[model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config[model_type]["pipeline_class"].from_model_manager(model_manager)
return model_manager, pipe
model_manager: ModelManager = None
pipe = None
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
with gr.Row():
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
def model_type_to_model_path(model_type):
return gr.Dropdown(choices=load_model_list(model_type))
with gr.Accordion(label="Prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale (only for FLUX)")
with gr.Accordion(label="Image"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width],
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
global model_manager, pipe
if isinstance(model_manager, ModelManager):
model_manager.to("cpu")
torch.cuda.empty_cache()
model_manager, pipe = load_model(model_type, model_path)
cfg_scale = config[model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config[model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config[model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config[model_type]["default_parameters"].get("height", height)
width = config[model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(MAX_NUM_PAINTER_LAYERS):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=100, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
global pipe
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * MAX_NUM_PAINTER_LAYERS: 1 * MAX_NUM_PAINTER_LAYERS],
args[1 * MAX_NUM_PAINTER_LAYERS: 2 * MAX_NUM_PAINTER_LAYERS],
args[2 * MAX_NUM_PAINTER_LAYERS: 3 * MAX_NUM_PAINTER_LAYERS],
args[3 * MAX_NUM_PAINTER_LAYERS: 4 * MAX_NUM_PAINTER_LAYERS]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
input_params.update({
"local_prompts": local_prompts,
"masks": masks,
"mask_scales": mask_scales,
})
torch.manual_seed(seed)
image = pipe(**input_params)
return image
@gr.on(inputs=[output_image] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(output_image, *canvas_list):
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = output_image.resize((w, h))
return tuple(canvas_list)
app.launch()

View File

@@ -0,0 +1,15 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Diasble virtual VRAM on windows system
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)
st.markdown("""
# DiffSynth Studio
[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
Welcome to DiffSynth Studio.
""")

View File

@@ -0,0 +1,362 @@
import torch, os, io, json, time
import numpy as np
from PIL import Image
import streamlit as st
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from diffsynth.models import ModelManager
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline, SD3ImagePipeline, HunyuanDiTImagePipeline, FluxImagePipeline
from diffsynth.data.video import crop_and_resize
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"fixed_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"fixed_parameters": {}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"fixed_parameters": {
"height": 1024,
"width": 1024,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"fixed_parameters": {
"cfg_scale": 1.0,
}
}
}
def load_model_list(model_type):
folder = config[model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
def release_model():
if "model_manager" in st.session_state:
st.session_state["model_manager"].to("cpu")
del st.session_state["loaded_model_path"]
del st.session_state["model_manager"]
del st.session_state["pipeline"]
torch.cuda.empty_cache()
def load_model(model_type, model_path):
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
os.path.join(model_path, "mt5/pytorch_model.bin"),
os.path.join(model_path, "model/pytorch_model_ema.pt"),
os.path.join(model_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
])
elif model_type == "Kolors":
model_manager.load_models([
os.path.join(model_path, "text_encoder"),
os.path.join(model_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
])
elif model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipeline = config[model_type]["pipeline_class"].from_model_manager(model_manager)
st.session_state.loaded_model_path = model_path
st.session_state.model_manager = model_manager
st.session_state.pipeline = pipeline
return model_manager, pipeline
def use_output_image_as_input(update=True):
# Search for input image
output_image_id = 0
selected_output_image = None
while True:
if f"use_output_as_input_{output_image_id}" not in st.session_state:
break
if st.session_state[f"use_output_as_input_{output_image_id}"]:
selected_output_image = st.session_state["output_images"][output_image_id]
break
output_image_id += 1
if update and selected_output_image is not None:
st.session_state["input_image"] = selected_output_image
return selected_output_image is not None
def apply_stroke_to_image(stroke_image, image):
image = np.array(image.convert("RGB")).astype(np.float32)
height, width, _ = image.shape
stroke_image = np.array(Image.fromarray(stroke_image).resize((width, height))).astype(np.float32)
weight = stroke_image[:, :, -1:] / 255
stroke_image = stroke_image[:, :, :-1]
image = stroke_image * weight + image * (1 - weight)
image = np.clip(image, 0, 255).astype(np.uint8)
image = Image.fromarray(image)
return image
@st.cache_data
def image2bits(image):
image_byte = io.BytesIO()
image.save(image_byte, format="PNG")
image_byte = image_byte.getvalue()
return image_byte
def show_output_image(image):
st.image(image, use_column_width="always")
st.button("Use it as input image", key=f"use_output_as_input_{image_id}")
st.download_button("Download", data=image2bits(image), file_name="image.png", mime="image/png", key=f"download_output_{image_id}")
column_input, column_output = st.columns(2)
with st.sidebar:
# Select a model
with st.expander("Model", expanded=True):
model_type = st.selectbox("Model type", [model_type_ for model_type_ in config])
fixed_parameters = config[model_type]["fixed_parameters"]
model_path_list = ["None"] + load_model_list(model_type)
model_path = st.selectbox("Model path", model_path_list)
# Load the model
if model_path == "None":
# No models are selected. Release VRAM.
st.markdown("No models are selected.")
release_model()
else:
# A model is selected.
model_path = os.path.join(config[model_type]["model_folder"], model_path)
if st.session_state.get("loaded_model_path", "") != model_path:
# The loaded model is not the selected model. Reload it.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
release_model()
model_manager, pipeline = load_model(model_type, model_path)
st.markdown("Done.")
else:
# The loaded model is not the selected model. Fetch it from `st.session_state`.
st.markdown(f"Loading model at {model_path}.")
st.markdown("Please wait a moment...")
model_manager, pipeline = st.session_state.model_manager, st.session_state.pipeline
st.markdown("Done.")
# Show parameters
with st.expander("Prompt", expanded=True):
prompt = st.text_area("Positive prompt")
if "negative_prompt" in fixed_parameters:
negative_prompt = fixed_parameters["negative_prompt"]
else:
negative_prompt = st.text_area("Negative prompt")
if "cfg_scale" in fixed_parameters:
cfg_scale = fixed_parameters["cfg_scale"]
else:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.5)
with st.expander("Image", expanded=True):
if "num_inference_steps" in fixed_parameters:
num_inference_steps = fixed_parameters["num_inference_steps"]
else:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=20)
if "height" in fixed_parameters:
height = fixed_parameters["height"]
else:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 2048], value=512)
if "width" in fixed_parameters:
width = fixed_parameters["width"]
else:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 2048], value=512)
num_images = st.number_input("Number of images", value=2)
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
# Other fixed parameters
denoising_strength = 1.0
repetition = 1
# Show input image
with column_input:
with st.expander("Input image (Optional)", expanded=True):
with st.container(border=True):
column_white_board, column_upload_image = st.columns([1, 2])
with column_white_board:
create_white_board = st.button("Create white board")
delete_input_image = st.button("Delete input image")
with column_upload_image:
upload_image = st.file_uploader("Upload image", type=["png", "jpg"], key="upload_image")
if upload_image is not None:
st.session_state["input_image"] = crop_and_resize(Image.open(upload_image), height, width)
elif create_white_board:
st.session_state["input_image"] = Image.fromarray(np.ones((height, width, 3), dtype=np.uint8) * 255)
else:
use_output_image_as_input()
if delete_input_image and "input_image" in st.session_state:
del st.session_state.input_image
if delete_input_image and "upload_image" in st.session_state:
del st.session_state.upload_image
input_image = st.session_state.get("input_image", None)
if input_image is not None:
with st.container(border=True):
column_drawing_mode, column_color_1, column_color_2 = st.columns([4, 1, 1])
with column_drawing_mode:
drawing_mode = st.radio("Drawing tool", ["transform", "freedraw", "line", "rect"], horizontal=True, index=1)
with column_color_1:
stroke_color = st.color_picker("Stroke color")
with column_color_2:
fill_color = st.color_picker("Fill color")
stroke_width = st.slider("Stroke width", min_value=1, max_value=50, value=10)
with st.container(border=True):
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=0.7)
repetition = st.slider("Repetition", min_value=1, max_value=8, value=1)
with st.container(border=True):
input_width, input_height = input_image.size
canvas_result = st_canvas(
fill_color=fill_color,
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color="rgba(255, 255, 255, 0)",
background_image=input_image,
update_streamlit=True,
height=int(512 / input_width * input_height),
width=512,
drawing_mode=drawing_mode,
key="canvas"
)
num_painter_layer = st.number_input("Number of painter layers", min_value=0, max_value=10, step=1, value=0)
local_prompts, masks, mask_scales = [], [], []
white_board = Image.fromarray(np.ones((512, 512, 3), dtype=np.uint8) * 255)
painter_layers_json_data = []
for painter_tab_id in range(num_painter_layer):
with st.expander(f"Painter layer {painter_tab_id}", expanded=True):
enable_local_prompt = st.checkbox(f"Enable prompt {painter_tab_id}", value=True)
local_prompt = st.text_area(f"Prompt {painter_tab_id}")
mask_scale = st.slider(f"Mask scale {painter_tab_id}", min_value=0.0, max_value=3.0, value=1.0)
stroke_width = st.slider(f"Stroke width {painter_tab_id}", min_value=1, max_value=300, value=100)
canvas_result_local = st_canvas(
fill_color="#000000",
stroke_width=stroke_width,
stroke_color="#000000",
background_color="rgba(255, 255, 255, 0)",
background_image=white_board,
update_streamlit=True,
height=512,
width=512,
drawing_mode="freedraw",
key=f"canvas_{painter_tab_id}"
)
if canvas_result_local.json_data is not None:
painter_layers_json_data.append(canvas_result_local.json_data.copy())
painter_layers_json_data[-1]["prompt"] = local_prompt
if enable_local_prompt:
local_prompts.append(local_prompt)
if canvas_result_local.image_data is not None:
mask = apply_stroke_to_image(canvas_result_local.image_data, white_board)
else:
mask = white_board
mask = Image.fromarray(255 - np.array(mask))
masks.append(mask)
mask_scales.append(mask_scale)
save_painter_layers = st.button("Save painter layers")
if save_painter_layers:
os.makedirs("data/painter_layers", exist_ok=True)
json_file_path = f"data/painter_layers/{time.time_ns()}.json"
with open(json_file_path, "w") as f:
json.dump(painter_layers_json_data, f, indent=4)
st.markdown(f"Painter layers are saved in {json_file_path}.")
with column_output:
run_button = st.button("Generate image", type="primary")
auto_update = st.checkbox("Auto update", value=False)
num_image_columns = st.slider("Columns", min_value=1, max_value=8, value=2)
image_columns = st.columns(num_image_columns)
# Run
if (run_button or auto_update) and model_path != "None":
if input_image is not None:
input_image = input_image.resize((width, height))
if canvas_result.image_data is not None:
input_image = apply_stroke_to_image(canvas_result.image_data, input_image)
output_images = []
for image_id in range(num_images * repetition):
if use_fixed_seed:
torch.manual_seed(seed + image_id)
else:
torch.manual_seed(np.random.randint(0, 10**9))
if image_id >= num_images:
input_image = output_images[image_id - num_images]
with image_columns[image_id % num_image_columns]:
progress_bar_st = st.progress(0.0)
image = pipeline(
prompt, negative_prompt=negative_prompt,
local_prompts=local_prompts, masks=masks, mask_scales=mask_scales,
cfg_scale=cfg_scale, num_inference_steps=num_inference_steps,
height=height, width=width,
input_image=input_image, denoising_strength=denoising_strength,
progress_bar_st=progress_bar_st
)
output_images.append(image)
progress_bar_st.progress(1.0)
show_output_image(image)
st.session_state["output_images"] = output_images
elif "output_images" in st.session_state:
for image_id in range(len(st.session_state.output_images)):
with image_columns[image_id % num_image_columns]:
image = st.session_state.output_images[image_id]
progress_bar = st.progress(1.0)
show_output_image(image)
if "upload_image" in st.session_state and use_output_image_as_input(update=False):
st.markdown("If you want to use an output image as input image, please delete the uploaded image manually.")

View File

@@ -0,0 +1,197 @@
import streamlit as st
st.set_page_config(layout="wide")
from diffsynth import SDVideoPipelineRunner
import os
import numpy as np
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
file_list = sorted(file_list)
return file_list
def match_processor_id(model_name, supported_processor_id_list):
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
for processor_id in sorted_processor_id:
if processor_id in model_name:
return supported_processor_id_list.index(processor_id) + 1
return 0
config = {
"models": {
"model_list": [],
"textual_inversion_folder": "models/textual_inversion",
"device": "cuda",
"lora_alphas": [],
"controlnet_units": []
},
"data": {
"input_frames": None,
"controlnet_frames": [],
"output_folder": "output",
"fps": 60
},
"pipeline": {
"seed": 0,
"pipeline_inputs": {}
}
}
with st.expander("Model", expanded=True):
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
if stable_diffusion_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
if animatediff_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
column_lora, column_lora_alpha = st.columns([2, 1])
with column_lora:
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
with column_lora_alpha:
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
if sd_lora_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
config["models"]["lora_alphas"].append(lora_alpha)
with st.expander("Data", expanded=True):
with st.container(border=True):
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0)
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16)
if input_video != "":
config["data"]["input_frames"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
with st.container(border=True):
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
fps = st.number_input("FPS", value=60)
config["data"]["output_folder"] = output_video
config["data"]["fps"] = fps
with st.expander("ControlNet Units", expanded=True):
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
for controlnet_id in range(len(controlnet_units)):
with controlnet_units[controlnet_id]:
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
key=f"controlnet_ckpt_{controlnet_id}")
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
disabled=controlnet_ckpt == "None",
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
if not use_input_video_as_controlnet_input:
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0,
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16,
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
if input_video != "":
config["data"]["input_video"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
if controlnet_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
config["models"]["controlnet_units"].append({
"processor_id": processor_id,
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
"scale": controlnet_scale,
})
if use_input_video_as_controlnet_input:
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
else:
config["data"]["controlnet_frames"].append({
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
})
with st.container(border=True):
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
else:
seed = np.random.randint(0, 10**9)
with st.expander("Textual Guidance", expanded=True):
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
column_cfg_scale, column_clip_skip = st.columns(2)
with column_cfg_scale:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
with column_clip_skip:
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
with st.expander("Denoising", expanded=True):
column_num_inference_steps, column_denoising_strength = st.columns(2)
with column_num_inference_steps:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
with column_denoising_strength:
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
with st.expander("Efficiency", expanded=False):
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
animatediff_stride = st.slider("Animatediff stride",
min_value=1,
max_value=max(2, animatediff_batch_size),
value=max(1, animatediff_batch_size // 2),
step=1)
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
config["pipeline"]["seed"] = seed
config["pipeline"]["pipeline_inputs"] = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"clip_skip": clip_skip,
"denoising_strength": denoising_strength,
"num_inference_steps": num_inference_steps,
"animatediff_batch_size": animatediff_batch_size,
"animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size,
"controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
run_button = st.button("Run☢", type="primary")
if run_button:
SDVideoPipelineRunner(in_streamlit=True).run(config)