mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add video UI
This commit is contained in:
@@ -12,6 +12,8 @@ Create Python environment:
|
||||
conda env create -f environment.yml
|
||||
```
|
||||
|
||||
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||
|
||||
Enter the Python environment:
|
||||
|
||||
```
|
||||
|
||||
@@ -23,13 +23,11 @@ class MultiControlNetManager:
|
||||
self.models = [unit.model for unit in controlnet_units]
|
||||
self.scales = [unit.scale for unit in controlnet_units]
|
||||
|
||||
def process_image(self, image, return_image=False):
|
||||
processed_image = [
|
||||
processor(image)
|
||||
for processor in self.processors
|
||||
]
|
||||
if return_image:
|
||||
return processed_image
|
||||
def process_image(self, image, processor_id=None):
|
||||
if processor_id is None:
|
||||
processed_image = [processor(image) for processor in self.processors]
|
||||
else:
|
||||
processed_image = [self.processors[processor_id](image)]
|
||||
processed_image = torch.concat([
|
||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
||||
for image_ in processed_image
|
||||
|
||||
@@ -16,15 +16,15 @@ class Annotator:
|
||||
if processor_id == "canny":
|
||||
self.processor = CannyDetector()
|
||||
elif processor_id == "depth":
|
||||
self.processor = MidasDetector.from_pretrained(model_path)
|
||||
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "softedge":
|
||||
self.processor = HEDdetector.from_pretrained(model_path)
|
||||
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart":
|
||||
self.processor = LineartDetector.from_pretrained(model_path)
|
||||
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "lineart_anime":
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path)
|
||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "openpose":
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path)
|
||||
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
||||
elif processor_id == "tile":
|
||||
self.processor = None
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ import cupy as cp
|
||||
class FastBlendSmoother:
|
||||
def __init__(self):
|
||||
self.batch_size = 8
|
||||
self.window_size = 32
|
||||
self.window_size = 64
|
||||
self.ebsynth_config = {
|
||||
"minimum_patch_size": 5,
|
||||
"threads_per_block": 8,
|
||||
|
||||
@@ -189,6 +189,7 @@ class ModelManager:
|
||||
model.to(device)
|
||||
else:
|
||||
self.model[component].to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_model_with_model_path(self, model_path):
|
||||
for component in self.model_path:
|
||||
|
||||
@@ -51,6 +51,9 @@ def lets_dance_with_long_video(
|
||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||
hidden_states_output[i] = (hidden_states, num + 1)
|
||||
|
||||
if batch_id_ == num_frames:
|
||||
break
|
||||
|
||||
# output
|
||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||
return hidden_states
|
||||
@@ -195,10 +198,21 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
|
||||
# Prepare ControlNets
|
||||
if controlnet_frames is not None:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
if isinstance(controlnet_frames[0], list):
|
||||
controlnet_frames_ = []
|
||||
for processor_id in range(len(controlnet_frames)):
|
||||
controlnet_frames_.append(
|
||||
torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||
], dim=1)
|
||||
)
|
||||
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||
else:
|
||||
controlnet_frames = torch.stack([
|
||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||
], dim=1)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
|
||||
@@ -8,6 +8,7 @@ dependencies:
|
||||
- pip=23.0.1
|
||||
- cudatoolkit
|
||||
- pytorch
|
||||
- cupy
|
||||
- pip:
|
||||
- transformers
|
||||
- controlnet-aux==0.0.7
|
||||
|
||||
@@ -2,7 +2,6 @@ import torch, os, io
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
from diffsynth.models import ModelManager
|
||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
||||
|
||||
@@ -1,4 +1,263 @@
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
||||
import torch, os, json
|
||||
import numpy as np
|
||||
|
||||
st.markdown("# Coming soon")
|
||||
|
||||
class Runner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||
model_manager.load_models(model_list)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
st.markdown("Loading videos ... done!")
|
||||
st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
st.markdown("Loading models ... done!")
|
||||
st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
||||
st.markdown("Synthesizing videos ... done!")
|
||||
st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
st.markdown("Saving videos ... done!")
|
||||
st.markdown("Finished!")
|
||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||
st.video(video_file.read())
|
||||
|
||||
|
||||
|
||||
def load_model_list(folder):
|
||||
file_list = os.listdir(folder)
|
||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
||||
file_list = sorted(file_list)
|
||||
return file_list
|
||||
|
||||
|
||||
def match_processor_id(model_name, supported_processor_id_list):
|
||||
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
|
||||
for processor_id in sorted_processor_id:
|
||||
if processor_id in model_name:
|
||||
return supported_processor_id_list.index(processor_id) + 1
|
||||
return 0
|
||||
|
||||
|
||||
config = {
|
||||
"models": {
|
||||
"model_list": [],
|
||||
"textual_inversion_folder": "models/textual_inversion",
|
||||
"device": "cuda",
|
||||
"controlnet_units": []
|
||||
},
|
||||
"data": {
|
||||
"input_frames": None,
|
||||
"controlnet_frames": [],
|
||||
"output_folder": "output",
|
||||
"fps": 60
|
||||
},
|
||||
"pipeline": {
|
||||
"seed": 0,
|
||||
"pipeline_inputs": {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
with st.expander("Model", expanded=True):
|
||||
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
|
||||
if stable_diffusion_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
|
||||
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
||||
if animatediff_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
||||
|
||||
|
||||
with st.expander("Data", expanded=True):
|
||||
with st.container(border=True):
|
||||
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
|
||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||
with column_height:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||
with column_width:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||
with column_start_frame_index:
|
||||
start_frame_id = st.number_input("Start Frame id", value=0)
|
||||
with column_end_frame_index:
|
||||
end_frame_id = st.number_input("End Frame id", value=16)
|
||||
if input_video != "":
|
||||
config["data"]["input_frames"] = {
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
}
|
||||
with st.container(border=True):
|
||||
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
|
||||
fps = st.number_input("FPS", value=60)
|
||||
config["data"]["output_folder"] = output_video
|
||||
config["data"]["fps"] = fps
|
||||
|
||||
|
||||
with st.expander("ControlNet Units", expanded=True):
|
||||
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
|
||||
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
|
||||
for controlnet_id in range(len(controlnet_units)):
|
||||
with controlnet_units[controlnet_id]:
|
||||
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
|
||||
key=f"controlnet_ckpt_{controlnet_id}")
|
||||
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
|
||||
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
|
||||
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
|
||||
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
|
||||
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
|
||||
disabled=controlnet_ckpt == "None",
|
||||
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
|
||||
if not use_input_video_as_controlnet_input:
|
||||
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
|
||||
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||
with column_height:
|
||||
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
|
||||
with column_width:
|
||||
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
|
||||
with column_start_frame_index:
|
||||
start_frame_id = st.number_input("Start Frame id", value=0,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
|
||||
with column_end_frame_index:
|
||||
end_frame_id = st.number_input("End Frame id", value=16,
|
||||
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
|
||||
if input_video != "":
|
||||
config["data"]["input_video"] = {
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
}
|
||||
if controlnet_ckpt != "None":
|
||||
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
|
||||
config["models"]["controlnet_units"].append({
|
||||
"processor_id": processor_id,
|
||||
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
|
||||
"scale": controlnet_scale,
|
||||
})
|
||||
if use_input_video_as_controlnet_input:
|
||||
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
|
||||
else:
|
||||
config["data"]["controlnet_frames"].append({
|
||||
"video_file": input_video,
|
||||
"image_folder": None,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"start_frame_id": start_frame_id,
|
||||
"end_frame_id": end_frame_id
|
||||
})
|
||||
|
||||
|
||||
with st.container(border=True):
|
||||
with st.expander("Seed", expanded=True):
|
||||
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
||||
if use_fixed_seed:
|
||||
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
||||
else:
|
||||
seed = np.random.randint(0, 10**9)
|
||||
with st.expander("Textual Guidance", expanded=True):
|
||||
prompt = st.text_area("Positive prompt")
|
||||
negative_prompt = st.text_area("Negative prompt")
|
||||
column_cfg_scale, column_clip_skip = st.columns(2)
|
||||
with column_cfg_scale:
|
||||
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
|
||||
with column_clip_skip:
|
||||
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
|
||||
with st.expander("Denoising", expanded=True):
|
||||
column_num_inference_steps, column_denoising_strength = st.columns(2)
|
||||
with column_num_inference_steps:
|
||||
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
|
||||
with column_denoising_strength:
|
||||
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
|
||||
with st.expander("Efficiency", expanded=False):
|
||||
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
|
||||
animatediff_stride = st.slider("Animatediff stride",
|
||||
min_value=1,
|
||||
max_value=max(2, animatediff_batch_size),
|
||||
value=max(1, animatediff_batch_size // 2),
|
||||
step=1)
|
||||
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
|
||||
config["pipeline"]["seed"] = seed
|
||||
config["pipeline"]["pipeline_inputs"] = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"cfg_scale": cfg_scale,
|
||||
"clip_skip": clip_skip,
|
||||
"denoising_strength": denoising_strength,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"animatediff_batch_size": animatediff_batch_size,
|
||||
"animatediff_stride": animatediff_stride,
|
||||
"unet_batch_size": unet_batch_size,
|
||||
"controlnet_batch_size": controlnet_batch_size,
|
||||
"cross_frame_attention": cross_frame_attention,
|
||||
}
|
||||
|
||||
run_button = st.button("☢️Run☢️", type="primary")
|
||||
if run_button:
|
||||
Runner().run(config)
|
||||
Reference in New Issue
Block a user