Files
DiffSynth-Studio/apps/streamlit/pages/2_Video_Creator.py
Artiprocher d6d14859e3 update UI
2024-08-21 16:57:56 +08:00

198 lines
10 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import streamlit as st
st.set_page_config(layout="wide")
from diffsynth import SDVideoPipelineRunner
import os
import numpy as np
def load_model_list(folder):
file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
file_list = sorted(file_list)
return file_list
def match_processor_id(model_name, supported_processor_id_list):
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
for processor_id in sorted_processor_id:
if processor_id in model_name:
return supported_processor_id_list.index(processor_id) + 1
return 0
config = {
"models": {
"model_list": [],
"textual_inversion_folder": "models/textual_inversion",
"device": "cuda",
"lora_alphas": [],
"controlnet_units": []
},
"data": {
"input_frames": None,
"controlnet_frames": [],
"output_folder": "output",
"fps": 60
},
"pipeline": {
"seed": 0,
"pipeline_inputs": {}
}
}
with st.expander("Model", expanded=True):
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
if stable_diffusion_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
if animatediff_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
column_lora, column_lora_alpha = st.columns([2, 1])
with column_lora:
sd_lora_ckpt = st.selectbox("LoRA", ["None"] + load_model_list("models/lora"))
with column_lora_alpha:
lora_alpha = st.slider("LoRA Alpha", min_value=-4.0, max_value=4.0, value=1.0, step=0.1)
if sd_lora_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/lora", sd_lora_ckpt))
config["models"]["lora_alphas"].append(lora_alpha)
with st.expander("Data", expanded=True):
with st.container(border=True):
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0)
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16)
if input_video != "":
config["data"]["input_frames"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
with st.container(border=True):
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
fps = st.number_input("FPS", value=60)
config["data"]["output_folder"] = output_video
config["data"]["fps"] = fps
with st.expander("ControlNet Units", expanded=True):
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
for controlnet_id in range(len(controlnet_units)):
with controlnet_units[controlnet_id]:
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
key=f"controlnet_ckpt_{controlnet_id}")
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
disabled=controlnet_ckpt == "None",
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
if not use_input_video_as_controlnet_input:
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
with column_height:
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
with column_width:
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
with column_start_frame_index:
start_frame_id = st.number_input("Start Frame id", value=0,
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
with column_end_frame_index:
end_frame_id = st.number_input("End Frame id", value=16,
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
if input_video != "":
config["data"]["input_video"] = {
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
}
if controlnet_ckpt != "None":
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
config["models"]["controlnet_units"].append({
"processor_id": processor_id,
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
"scale": controlnet_scale,
})
if use_input_video_as_controlnet_input:
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
else:
config["data"]["controlnet_frames"].append({
"video_file": input_video,
"image_folder": None,
"height": height,
"width": width,
"start_frame_id": start_frame_id,
"end_frame_id": end_frame_id
})
with st.container(border=True):
with st.expander("Seed", expanded=True):
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
if use_fixed_seed:
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
else:
seed = np.random.randint(0, 10**9)
with st.expander("Textual Guidance", expanded=True):
prompt = st.text_area("Positive prompt")
negative_prompt = st.text_area("Negative prompt")
column_cfg_scale, column_clip_skip = st.columns(2)
with column_cfg_scale:
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
with column_clip_skip:
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
with st.expander("Denoising", expanded=True):
column_num_inference_steps, column_denoising_strength = st.columns(2)
with column_num_inference_steps:
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
with column_denoising_strength:
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
with st.expander("Efficiency", expanded=False):
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
animatediff_stride = st.slider("Animatediff stride",
min_value=1,
max_value=max(2, animatediff_batch_size),
value=max(1, animatediff_batch_size // 2),
step=1)
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
config["pipeline"]["seed"] = seed
config["pipeline"]["pipeline_inputs"] = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"clip_skip": clip_skip,
"denoising_strength": denoising_strength,
"num_inference_steps": num_inference_steps,
"animatediff_batch_size": animatediff_batch_size,
"animatediff_stride": animatediff_stride,
"unet_batch_size": unet_batch_size,
"controlnet_batch_size": controlnet_batch_size,
"cross_frame_attention": cross_frame_attention,
}
run_button = st.button("Run☢", type="primary")
if run_button:
SDVideoPipelineRunner(in_streamlit=True).run(config)