mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add independent config runner
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||
|
||||
@@ -2,9 +2,10 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEnc
|
||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||
from ..prompts import SDPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from ..data import VideoData, save_frames, save_video
|
||||
from .dancer import lets_dance
|
||||
from typing import List
|
||||
import torch
|
||||
import torch, os, json
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@@ -254,3 +255,85 @@ class SDVideoPipeline(torch.nn.Module):
|
||||
output_frames = self.decode_images(latents)
|
||||
|
||||
return output_frames
|
||||
|
||||
|
||||
|
||||
class SDVideoPipelineRunner:
|
||||
def __init__(self, in_streamlit=False):
|
||||
self.in_streamlit = in_streamlit
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
else:
|
||||
output_video = pipe(**pipeline_inputs)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
if self.in_streamlit:
|
||||
import streamlit as st
|
||||
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
||||
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||
if self.in_streamlit: st.markdown("Finished!")
|
||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||
if self.in_streamlit: st.video(video_file.read())
|
||||
|
||||
@@ -1,86 +1,10 @@
|
||||
import streamlit as st
|
||||
st.set_page_config(layout="wide")
|
||||
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
||||
import torch, os, json
|
||||
from diffsynth import SDVideoPipelineRunner
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||
pipe = SDVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
[
|
||||
ControlNetConfigUnit(
|
||||
processor_id=unit["processor_id"],
|
||||
model_path=unit["model_path"],
|
||||
scale=unit["scale"]
|
||||
) for unit in controlnet_units
|
||||
]
|
||||
)
|
||||
return model_manager, pipe
|
||||
|
||||
|
||||
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
||||
torch.manual_seed(seed)
|
||||
progress_bar_st = st.progress(0.0)
|
||||
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
||||
progress_bar_st.progress(1.0)
|
||||
model_manager.to("cpu")
|
||||
return output_video
|
||||
|
||||
|
||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||
return frames
|
||||
|
||||
|
||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||
if len(data["controlnet_frames"]) > 0:
|
||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||
return pipeline_inputs
|
||||
|
||||
|
||||
def save_output(self, video, output_folder, fps, config):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
save_frames(video, os.path.join(output_folder, "frames"))
|
||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||
json.dump(config, file, indent=4)
|
||||
|
||||
|
||||
def run(self, config):
|
||||
st.markdown("Loading videos ...")
|
||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||
st.markdown("Loading videos ... done!")
|
||||
st.markdown("Loading models ...")
|
||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||
st.markdown("Loading models ... done!")
|
||||
st.markdown("Synthesizing videos ...")
|
||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
||||
st.markdown("Synthesizing videos ... done!")
|
||||
st.markdown("Saving videos ...")
|
||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||
st.markdown("Saving videos ... done!")
|
||||
st.markdown("Finished!")
|
||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||
st.video(video_file.read())
|
||||
|
||||
|
||||
|
||||
def load_model_list(folder):
|
||||
file_list = os.listdir(folder)
|
||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
||||
@@ -270,4 +194,4 @@ with st.container(border=True):
|
||||
|
||||
run_button = st.button("☢️Run☢️", type="primary")
|
||||
if run_button:
|
||||
Runner().run(config)
|
||||
SDVideoPipelineRunner(in_streamlit=True).run(config)
|
||||
|
||||
Reference in New Issue
Block a user