diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index 5074e75..4f97931 100644 --- a/diffsynth/pipelines/__init__.py +++ b/diffsynth/pipelines/__init__.py @@ -1,3 +1,3 @@ from .stable_diffusion import SDImagePipeline from .stable_diffusion_xl import SDXLImagePipeline -from .stable_diffusion_video import SDVideoPipeline +from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index fe876e6..0eefb22 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -2,9 +2,10 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEnc from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..prompts import SDPrompter from ..schedulers import EnhancedDDIMScheduler +from ..data import VideoData, save_frames, save_video from .dancer import lets_dance from typing import List -import torch +import torch, os, json from tqdm import tqdm from PIL import Image import numpy as np @@ -254,3 +255,85 @@ class SDVideoPipeline(torch.nn.Module): output_frames = self.decode_images(latents) return output_frames + + + +class SDVideoPipelineRunner: + def __init__(self, in_streamlit=False): + self.in_streamlit = in_streamlit + + + def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units): + # Load models + model_manager = ModelManager(torch_dtype=torch.float16, device=device) + model_manager.load_textual_inversions(textual_inversion_folder) + model_manager.load_models(model_list, lora_alphas=lora_alphas) + pipe = SDVideoPipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id=unit["processor_id"], + model_path=unit["model_path"], + scale=unit["scale"] + ) for unit in controlnet_units + ] + ) + return model_manager, pipe + + + def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs): + torch.manual_seed(seed) + if self.in_streamlit: + import streamlit as st + progress_bar_st = st.progress(0.0) + output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st) + progress_bar_st.progress(1.0) + else: + output_video = pipe(**pipeline_inputs) + model_manager.to("cpu") + return output_video + + + def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id): + video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width) + frames = [video[i] for i in range(start_frame_id, end_frame_id)] + return frames + + + def add_data_to_pipeline_inputs(self, data, pipeline_inputs): + pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"]) + pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"]) + pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size + if len(data["controlnet_frames"]) > 0: + pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]] + return pipeline_inputs + + + def save_output(self, video, output_folder, fps, config): + os.makedirs(output_folder, exist_ok=True) + save_frames(video, os.path.join(output_folder, "frames")) + save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps) + config["pipeline"]["pipeline_inputs"]["input_frames"] = [] + config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = [] + with open(os.path.join(output_folder, "config.json"), 'w') as file: + json.dump(config, file, indent=4) + + + def run(self, config): + if self.in_streamlit: + import streamlit as st + if self.in_streamlit: st.markdown("Loading videos ...") + config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"]) + if self.in_streamlit: st.markdown("Loading videos ... done!") + if self.in_streamlit: st.markdown("Loading models ...") + model_manager, pipe = self.load_pipeline(**config["models"]) + if self.in_streamlit: st.markdown("Loading models ... done!") + if self.in_streamlit: st.markdown("Synthesizing videos ...") + output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"]) + if self.in_streamlit: st.markdown("Synthesizing videos ... done!") + if self.in_streamlit: st.markdown("Saving videos ...") + self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config) + if self.in_streamlit: st.markdown("Saving videos ... done!") + if self.in_streamlit: st.markdown("Finished!") + video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb') + if self.in_streamlit: st.video(video_file.read()) diff --git a/pages/2_Video_Creator.py b/pages/2_Video_Creator.py index 15ebb5e..8748072 100644 --- a/pages/2_Video_Creator.py +++ b/pages/2_Video_Creator.py @@ -1,86 +1,10 @@ import streamlit as st st.set_page_config(layout="wide") -from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames -import torch, os, json +from diffsynth import SDVideoPipelineRunner +import os import numpy as np -class Runner: - def __init__(self): - pass - - - def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units): - # Load models - model_manager = ModelManager(torch_dtype=torch.float16, device=device) - model_manager.load_textual_inversions(textual_inversion_folder) - model_manager.load_models(model_list, lora_alphas=lora_alphas) - pipe = SDVideoPipeline.from_model_manager( - model_manager, - [ - ControlNetConfigUnit( - processor_id=unit["processor_id"], - model_path=unit["model_path"], - scale=unit["scale"] - ) for unit in controlnet_units - ] - ) - return model_manager, pipe - - - def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs): - torch.manual_seed(seed) - progress_bar_st = st.progress(0.0) - output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st) - progress_bar_st.progress(1.0) - model_manager.to("cpu") - return output_video - - - def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id): - video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width) - frames = [video[i] for i in range(start_frame_id, end_frame_id)] - return frames - - - def add_data_to_pipeline_inputs(self, data, pipeline_inputs): - pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"]) - pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"]) - pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size - if len(data["controlnet_frames"]) > 0: - pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]] - return pipeline_inputs - - - def save_output(self, video, output_folder, fps, config): - os.makedirs(output_folder, exist_ok=True) - save_frames(video, os.path.join(output_folder, "frames")) - save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps) - config["pipeline"]["pipeline_inputs"]["input_frames"] = [] - config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = [] - with open(os.path.join(output_folder, "config.json"), 'w') as file: - json.dump(config, file, indent=4) - - - def run(self, config): - st.markdown("Loading videos ...") - config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"]) - st.markdown("Loading videos ... done!") - st.markdown("Loading models ...") - model_manager, pipe = self.load_pipeline(**config["models"]) - st.markdown("Loading models ... done!") - st.markdown("Synthesizing videos ...") - output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"]) - st.markdown("Synthesizing videos ... done!") - st.markdown("Saving videos ...") - self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config) - st.markdown("Saving videos ... done!") - st.markdown("Finished!") - video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb') - st.video(video_file.read()) - - - def load_model_list(folder): file_list = os.listdir(folder) file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")] @@ -270,4 +194,4 @@ with st.container(border=True): run_button = st.button("☢️Run☢️", type="primary") if run_button: - Runner().run(config) \ No newline at end of file + SDVideoPipelineRunner(in_streamlit=True).run(config)