mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
add video UI
This commit is contained in:
@@ -12,6 +12,8 @@ Create Python environment:
|
|||||||
conda env create -f environment.yml
|
conda env create -f environment.yml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
|
||||||
|
|
||||||
Enter the Python environment:
|
Enter the Python environment:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -23,13 +23,11 @@ class MultiControlNetManager:
|
|||||||
self.models = [unit.model for unit in controlnet_units]
|
self.models = [unit.model for unit in controlnet_units]
|
||||||
self.scales = [unit.scale for unit in controlnet_units]
|
self.scales = [unit.scale for unit in controlnet_units]
|
||||||
|
|
||||||
def process_image(self, image, return_image=False):
|
def process_image(self, image, processor_id=None):
|
||||||
processed_image = [
|
if processor_id is None:
|
||||||
processor(image)
|
processed_image = [processor(image) for processor in self.processors]
|
||||||
for processor in self.processors
|
else:
|
||||||
]
|
processed_image = [self.processors[processor_id](image)]
|
||||||
if return_image:
|
|
||||||
return processed_image
|
|
||||||
processed_image = torch.concat([
|
processed_image = torch.concat([
|
||||||
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
||||||
for image_ in processed_image
|
for image_ in processed_image
|
||||||
|
|||||||
@@ -16,15 +16,15 @@ class Annotator:
|
|||||||
if processor_id == "canny":
|
if processor_id == "canny":
|
||||||
self.processor = CannyDetector()
|
self.processor = CannyDetector()
|
||||||
elif processor_id == "depth":
|
elif processor_id == "depth":
|
||||||
self.processor = MidasDetector.from_pretrained(model_path)
|
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
||||||
elif processor_id == "softedge":
|
elif processor_id == "softedge":
|
||||||
self.processor = HEDdetector.from_pretrained(model_path)
|
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
||||||
elif processor_id == "lineart":
|
elif processor_id == "lineart":
|
||||||
self.processor = LineartDetector.from_pretrained(model_path)
|
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
||||||
elif processor_id == "lineart_anime":
|
elif processor_id == "lineart_anime":
|
||||||
self.processor = LineartAnimeDetector.from_pretrained(model_path)
|
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
||||||
elif processor_id == "openpose":
|
elif processor_id == "openpose":
|
||||||
self.processor = OpenposeDetector.from_pretrained(model_path)
|
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
||||||
elif processor_id == "tile":
|
elif processor_id == "tile":
|
||||||
self.processor = None
|
self.processor = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import cupy as cp
|
|||||||
class FastBlendSmoother:
|
class FastBlendSmoother:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.batch_size = 8
|
self.batch_size = 8
|
||||||
self.window_size = 32
|
self.window_size = 64
|
||||||
self.ebsynth_config = {
|
self.ebsynth_config = {
|
||||||
"minimum_patch_size": 5,
|
"minimum_patch_size": 5,
|
||||||
"threads_per_block": 8,
|
"threads_per_block": 8,
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ class ModelManager:
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
else:
|
else:
|
||||||
self.model[component].to(device)
|
self.model[component].to(device)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_model_with_model_path(self, model_path):
|
def get_model_with_model_path(self, model_path):
|
||||||
for component in self.model_path:
|
for component in self.model_path:
|
||||||
|
|||||||
@@ -51,6 +51,9 @@ def lets_dance_with_long_video(
|
|||||||
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias))
|
||||||
hidden_states_output[i] = (hidden_states, num + 1)
|
hidden_states_output[i] = (hidden_states, num + 1)
|
||||||
|
|
||||||
|
if batch_id_ == num_frames:
|
||||||
|
break
|
||||||
|
|
||||||
# output
|
# output
|
||||||
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
hidden_states = torch.stack([h for h, _ in hidden_states_output])
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -195,6 +198,17 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
|
|
||||||
# Prepare ControlNets
|
# Prepare ControlNets
|
||||||
if controlnet_frames is not None:
|
if controlnet_frames is not None:
|
||||||
|
if isinstance(controlnet_frames[0], list):
|
||||||
|
controlnet_frames_ = []
|
||||||
|
for processor_id in range(len(controlnet_frames)):
|
||||||
|
controlnet_frames_.append(
|
||||||
|
torch.stack([
|
||||||
|
self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype)
|
||||||
|
for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id])
|
||||||
|
], dim=1)
|
||||||
|
)
|
||||||
|
controlnet_frames = torch.concat(controlnet_frames_, dim=0)
|
||||||
|
else:
|
||||||
controlnet_frames = torch.stack([
|
controlnet_frames = torch.stack([
|
||||||
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
self.controlnet.process_image(controlnet_frame).to(self.torch_dtype)
|
||||||
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
for controlnet_frame in progress_bar_cmd(controlnet_frames)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ dependencies:
|
|||||||
- pip=23.0.1
|
- pip=23.0.1
|
||||||
- cudatoolkit
|
- cudatoolkit
|
||||||
- pytorch
|
- pytorch
|
||||||
|
- cupy
|
||||||
- pip:
|
- pip:
|
||||||
- transformers
|
- transformers
|
||||||
- controlnet-aux==0.0.7
|
- controlnet-aux==0.0.7
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import torch, os, io
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
st.set_page_config(layout="wide")
|
|
||||||
from streamlit_drawable_canvas import st_canvas
|
from streamlit_drawable_canvas import st_canvas
|
||||||
from diffsynth.models import ModelManager
|
from diffsynth.models import ModelManager
|
||||||
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline
|
||||||
|
|||||||
@@ -1,4 +1,263 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
st.set_page_config(layout="wide")
|
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
||||||
|
import torch, os, json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
st.markdown("# Coming soon")
|
|
||||||
|
class Runner:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units):
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||||
|
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||||
|
model_manager.load_models(model_list)
|
||||||
|
pipe = SDVideoPipeline.from_model_manager(
|
||||||
|
model_manager,
|
||||||
|
[
|
||||||
|
ControlNetConfigUnit(
|
||||||
|
processor_id=unit["processor_id"],
|
||||||
|
model_path=unit["model_path"],
|
||||||
|
scale=unit["scale"]
|
||||||
|
) for unit in controlnet_units
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return model_manager, pipe
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
progress_bar_st = st.progress(0.0)
|
||||||
|
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
||||||
|
progress_bar_st.progress(1.0)
|
||||||
|
model_manager.to("cpu")
|
||||||
|
return output_video
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||||
|
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||||
|
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||||
|
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||||
|
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||||
|
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||||
|
if len(data["controlnet_frames"]) > 0:
|
||||||
|
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||||
|
return pipeline_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def save_output(self, video, output_folder, fps, config):
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
save_frames(video, os.path.join(output_folder, "frames"))
|
||||||
|
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||||
|
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||||
|
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||||
|
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||||
|
json.dump(config, file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, config):
|
||||||
|
st.markdown("Loading videos ...")
|
||||||
|
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||||
|
st.markdown("Loading videos ... done!")
|
||||||
|
st.markdown("Loading models ...")
|
||||||
|
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||||
|
st.markdown("Loading models ... done!")
|
||||||
|
st.markdown("Synthesizing videos ...")
|
||||||
|
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
||||||
|
st.markdown("Synthesizing videos ... done!")
|
||||||
|
st.markdown("Saving videos ...")
|
||||||
|
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||||
|
st.markdown("Saving videos ... done!")
|
||||||
|
st.markdown("Finished!")
|
||||||
|
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||||
|
st.video(video_file.read())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_list(folder):
|
||||||
|
file_list = os.listdir(folder)
|
||||||
|
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
||||||
|
file_list = sorted(file_list)
|
||||||
|
return file_list
|
||||||
|
|
||||||
|
|
||||||
|
def match_processor_id(model_name, supported_processor_id_list):
|
||||||
|
sorted_processor_id = [i[1] for i in sorted([(-len(i), i) for i in supported_processor_id_list])]
|
||||||
|
for processor_id in sorted_processor_id:
|
||||||
|
if processor_id in model_name:
|
||||||
|
return supported_processor_id_list.index(processor_id) + 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"models": {
|
||||||
|
"model_list": [],
|
||||||
|
"textual_inversion_folder": "models/textual_inversion",
|
||||||
|
"device": "cuda",
|
||||||
|
"controlnet_units": []
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"input_frames": None,
|
||||||
|
"controlnet_frames": [],
|
||||||
|
"output_folder": "output",
|
||||||
|
"fps": 60
|
||||||
|
},
|
||||||
|
"pipeline": {
|
||||||
|
"seed": 0,
|
||||||
|
"pipeline_inputs": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
with st.expander("Model", expanded=True):
|
||||||
|
stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
|
||||||
|
if stable_diffusion_ckpt != "None":
|
||||||
|
config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
|
||||||
|
animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
|
||||||
|
if animatediff_ckpt != "None":
|
||||||
|
config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))
|
||||||
|
|
||||||
|
|
||||||
|
with st.expander("Data", expanded=True):
|
||||||
|
with st.container(border=True):
|
||||||
|
input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
|
||||||
|
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||||
|
with column_height:
|
||||||
|
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||||
|
with column_width:
|
||||||
|
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024)
|
||||||
|
with column_start_frame_index:
|
||||||
|
start_frame_id = st.number_input("Start Frame id", value=0)
|
||||||
|
with column_end_frame_index:
|
||||||
|
end_frame_id = st.number_input("End Frame id", value=16)
|
||||||
|
if input_video != "":
|
||||||
|
config["data"]["input_frames"] = {
|
||||||
|
"video_file": input_video,
|
||||||
|
"image_folder": None,
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
"start_frame_id": start_frame_id,
|
||||||
|
"end_frame_id": end_frame_id
|
||||||
|
}
|
||||||
|
with st.container(border=True):
|
||||||
|
output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
|
||||||
|
fps = st.number_input("FPS", value=60)
|
||||||
|
config["data"]["output_folder"] = output_video
|
||||||
|
config["data"]["fps"] = fps
|
||||||
|
|
||||||
|
|
||||||
|
with st.expander("ControlNet Units", expanded=True):
|
||||||
|
supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
|
||||||
|
controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
|
||||||
|
for controlnet_id in range(len(controlnet_units)):
|
||||||
|
with controlnet_units[controlnet_id]:
|
||||||
|
controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
|
||||||
|
key=f"controlnet_ckpt_{controlnet_id}")
|
||||||
|
processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
|
||||||
|
index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"processor_id_{controlnet_id}")
|
||||||
|
controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_scale_{controlnet_id}")
|
||||||
|
use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
|
||||||
|
disabled=controlnet_ckpt == "None",
|
||||||
|
key=f"use_input_video_as_controlnet_input_{controlnet_id}")
|
||||||
|
if not use_input_video_as_controlnet_input:
|
||||||
|
controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_input_video_{controlnet_id}")
|
||||||
|
column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
|
||||||
|
with column_height:
|
||||||
|
height = st.select_slider("Height", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_height_{controlnet_id}")
|
||||||
|
with column_width:
|
||||||
|
width = st.select_slider("Width", options=[256, 512, 768, 1024, 1536, 2048], value=1024,
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_width_{controlnet_id}")
|
||||||
|
with column_start_frame_index:
|
||||||
|
start_frame_id = st.number_input("Start Frame id", value=0,
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_start_frame_id_{controlnet_id}")
|
||||||
|
with column_end_frame_index:
|
||||||
|
end_frame_id = st.number_input("End Frame id", value=16,
|
||||||
|
disabled=controlnet_ckpt == "None", key=f"controlnet_end_frame_id_{controlnet_id}")
|
||||||
|
if input_video != "":
|
||||||
|
config["data"]["input_video"] = {
|
||||||
|
"video_file": input_video,
|
||||||
|
"image_folder": None,
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
"start_frame_id": start_frame_id,
|
||||||
|
"end_frame_id": end_frame_id
|
||||||
|
}
|
||||||
|
if controlnet_ckpt != "None":
|
||||||
|
config["models"]["model_list"].append(os.path.join("models/ControlNet", controlnet_ckpt))
|
||||||
|
config["models"]["controlnet_units"].append({
|
||||||
|
"processor_id": processor_id,
|
||||||
|
"model_path": os.path.join("models/ControlNet", controlnet_ckpt),
|
||||||
|
"scale": controlnet_scale,
|
||||||
|
})
|
||||||
|
if use_input_video_as_controlnet_input:
|
||||||
|
config["data"]["controlnet_frames"].append(config["data"]["input_frames"])
|
||||||
|
else:
|
||||||
|
config["data"]["controlnet_frames"].append({
|
||||||
|
"video_file": input_video,
|
||||||
|
"image_folder": None,
|
||||||
|
"height": height,
|
||||||
|
"width": width,
|
||||||
|
"start_frame_id": start_frame_id,
|
||||||
|
"end_frame_id": end_frame_id
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
with st.container(border=True):
|
||||||
|
with st.expander("Seed", expanded=True):
|
||||||
|
use_fixed_seed = st.checkbox("Use fixed seed", value=False)
|
||||||
|
if use_fixed_seed:
|
||||||
|
seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
|
||||||
|
else:
|
||||||
|
seed = np.random.randint(0, 10**9)
|
||||||
|
with st.expander("Textual Guidance", expanded=True):
|
||||||
|
prompt = st.text_area("Positive prompt")
|
||||||
|
negative_prompt = st.text_area("Negative prompt")
|
||||||
|
column_cfg_scale, column_clip_skip = st.columns(2)
|
||||||
|
with column_cfg_scale:
|
||||||
|
cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
|
||||||
|
with column_clip_skip:
|
||||||
|
clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
|
||||||
|
with st.expander("Denoising", expanded=True):
|
||||||
|
column_num_inference_steps, column_denoising_strength = st.columns(2)
|
||||||
|
with column_num_inference_steps:
|
||||||
|
num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
|
||||||
|
with column_denoising_strength:
|
||||||
|
denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
|
||||||
|
with st.expander("Efficiency", expanded=False):
|
||||||
|
animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
|
||||||
|
animatediff_stride = st.slider("Animatediff stride",
|
||||||
|
min_value=1,
|
||||||
|
max_value=max(2, animatediff_batch_size),
|
||||||
|
value=max(1, animatediff_batch_size // 2),
|
||||||
|
step=1)
|
||||||
|
unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||||
|
controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
|
||||||
|
cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)
|
||||||
|
config["pipeline"]["seed"] = seed
|
||||||
|
config["pipeline"]["pipeline_inputs"] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"clip_skip": clip_skip,
|
||||||
|
"denoising_strength": denoising_strength,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
"animatediff_batch_size": animatediff_batch_size,
|
||||||
|
"animatediff_stride": animatediff_stride,
|
||||||
|
"unet_batch_size": unet_batch_size,
|
||||||
|
"controlnet_batch_size": controlnet_batch_size,
|
||||||
|
"cross_frame_attention": cross_frame_attention,
|
||||||
|
}
|
||||||
|
|
||||||
|
run_button = st.button("☢️Run☢️", type="primary")
|
||||||
|
if run_button:
|
||||||
|
Runner().run(config)
|
||||||
Reference in New Issue
Block a user