add independent config runner

This commit is contained in:
Artiprocher
2024-01-11 14:02:17 +08:00
parent 9698e3988f
commit b28cb748b9
3 changed files with 88 additions and 81 deletions

View File

@@ -1,3 +1,3 @@
from .stable_diffusion import SDImagePipeline from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner

View File

@@ -2,9 +2,10 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEnc
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
from ..prompts import SDPrompter from ..prompts import SDPrompter
from ..schedulers import EnhancedDDIMScheduler from ..schedulers import EnhancedDDIMScheduler
from ..data import VideoData, save_frames, save_video
from .dancer import lets_dance from .dancer import lets_dance
from typing import List from typing import List
import torch import torch, os, json
from tqdm import tqdm from tqdm import tqdm
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@@ -254,3 +255,85 @@ class SDVideoPipeline(torch.nn.Module):
output_frames = self.decode_images(latents) output_frames = self.decode_images(latents)
return output_frames return output_frames
class SDVideoPipelineRunner:
def __init__(self, in_streamlit=False):
self.in_streamlit = in_streamlit
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_textual_inversions(textual_inversion_folder)
model_manager.load_models(model_list, lora_alphas=lora_alphas)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
return model_manager, pipe
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
torch.manual_seed(seed)
if self.in_streamlit:
import streamlit as st
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
else:
output_video = pipe(**pipeline_inputs)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
if self.in_streamlit:
import streamlit as st
if self.in_streamlit: st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Loading videos ... done!")
if self.in_streamlit: st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
if self.in_streamlit: st.markdown("Loading models ... done!")
if self.in_streamlit: st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
if self.in_streamlit: st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
if self.in_streamlit: st.markdown("Saving videos ... done!")
if self.in_streamlit: st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
if self.in_streamlit: st.video(video_file.read())

View File

@@ -1,86 +1,10 @@
import streamlit as st import streamlit as st
st.set_page_config(layout="wide") st.set_page_config(layout="wide")
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames from diffsynth import SDVideoPipelineRunner
import torch, os, json import os
import numpy as np import numpy as np
class Runner:
def __init__(self):
pass
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
model_manager.load_textual_inversions(textual_inversion_folder)
model_manager.load_models(model_list, lora_alphas=lora_alphas)
pipe = SDVideoPipeline.from_model_manager(
model_manager,
[
ControlNetConfigUnit(
processor_id=unit["processor_id"],
model_path=unit["model_path"],
scale=unit["scale"]
) for unit in controlnet_units
]
)
return model_manager, pipe
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
torch.manual_seed(seed)
progress_bar_st = st.progress(0.0)
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
progress_bar_st.progress(1.0)
model_manager.to("cpu")
return output_video
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
return frames
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
if len(data["controlnet_frames"]) > 0:
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
return pipeline_inputs
def save_output(self, video, output_folder, fps, config):
os.makedirs(output_folder, exist_ok=True)
save_frames(video, os.path.join(output_folder, "frames"))
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
with open(os.path.join(output_folder, "config.json"), 'w') as file:
json.dump(config, file, indent=4)
def run(self, config):
st.markdown("Loading videos ...")
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
st.markdown("Loading videos ... done!")
st.markdown("Loading models ...")
model_manager, pipe = self.load_pipeline(**config["models"])
st.markdown("Loading models ... done!")
st.markdown("Synthesizing videos ...")
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
st.markdown("Synthesizing videos ... done!")
st.markdown("Saving videos ...")
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
st.markdown("Saving videos ... done!")
st.markdown("Finished!")
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
st.video(video_file.read())
def load_model_list(folder): def load_model_list(folder):
file_list = os.listdir(folder) file_list = os.listdir(folder)
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")] file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
@@ -270,4 +194,4 @@ with st.container(border=True):
run_button = st.button("Run☢", type="primary") run_button = st.button("Run☢", type="primary")
if run_button: if run_button:
Runner().run(config) SDVideoPipelineRunner(in_streamlit=True).run(config)