mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 00:58:11 +00:00
add independent config runner
This commit is contained in:
@@ -1,3 +1,3 @@
|
|||||||
from .stable_diffusion import SDImagePipeline
|
from .stable_diffusion import SDImagePipeline
|
||||||
from .stable_diffusion_xl import SDXLImagePipeline
|
from .stable_diffusion_xl import SDXLImagePipeline
|
||||||
from .stable_diffusion_video import SDVideoPipeline
|
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ from ..models import ModelManager, SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEnc
|
|||||||
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
from ..controlnets import MultiControlNetManager, ControlNetUnit, ControlNetConfigUnit, Annotator
|
||||||
from ..prompts import SDPrompter
|
from ..prompts import SDPrompter
|
||||||
from ..schedulers import EnhancedDDIMScheduler
|
from ..schedulers import EnhancedDDIMScheduler
|
||||||
|
from ..data import VideoData, save_frames, save_video
|
||||||
from .dancer import lets_dance
|
from .dancer import lets_dance
|
||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch, os, json
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -254,3 +255,85 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
output_frames = self.decode_images(latents)
|
output_frames = self.decode_images(latents)
|
||||||
|
|
||||||
return output_frames
|
return output_frames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SDVideoPipelineRunner:
|
||||||
|
def __init__(self, in_streamlit=False):
|
||||||
|
self.in_streamlit = in_streamlit
|
||||||
|
|
||||||
|
|
||||||
|
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
||||||
|
# Load models
|
||||||
|
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
||||||
|
model_manager.load_textual_inversions(textual_inversion_folder)
|
||||||
|
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
||||||
|
pipe = SDVideoPipeline.from_model_manager(
|
||||||
|
model_manager,
|
||||||
|
[
|
||||||
|
ControlNetConfigUnit(
|
||||||
|
processor_id=unit["processor_id"],
|
||||||
|
model_path=unit["model_path"],
|
||||||
|
scale=unit["scale"]
|
||||||
|
) for unit in controlnet_units
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return model_manager, pipe
|
||||||
|
|
||||||
|
|
||||||
|
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if self.in_streamlit:
|
||||||
|
import streamlit as st
|
||||||
|
progress_bar_st = st.progress(0.0)
|
||||||
|
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
||||||
|
progress_bar_st.progress(1.0)
|
||||||
|
else:
|
||||||
|
output_video = pipe(**pipeline_inputs)
|
||||||
|
model_manager.to("cpu")
|
||||||
|
return output_video
|
||||||
|
|
||||||
|
|
||||||
|
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
||||||
|
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
||||||
|
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
||||||
|
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
||||||
|
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
||||||
|
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
||||||
|
if len(data["controlnet_frames"]) > 0:
|
||||||
|
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
||||||
|
return pipeline_inputs
|
||||||
|
|
||||||
|
|
||||||
|
def save_output(self, video, output_folder, fps, config):
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
save_frames(video, os.path.join(output_folder, "frames"))
|
||||||
|
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
||||||
|
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
||||||
|
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
||||||
|
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
||||||
|
json.dump(config, file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, config):
|
||||||
|
if self.in_streamlit:
|
||||||
|
import streamlit as st
|
||||||
|
if self.in_streamlit: st.markdown("Loading videos ...")
|
||||||
|
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
||||||
|
if self.in_streamlit: st.markdown("Loading videos ... done!")
|
||||||
|
if self.in_streamlit: st.markdown("Loading models ...")
|
||||||
|
model_manager, pipe = self.load_pipeline(**config["models"])
|
||||||
|
if self.in_streamlit: st.markdown("Loading models ... done!")
|
||||||
|
if self.in_streamlit: st.markdown("Synthesizing videos ...")
|
||||||
|
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
||||||
|
if self.in_streamlit: st.markdown("Synthesizing videos ... done!")
|
||||||
|
if self.in_streamlit: st.markdown("Saving videos ...")
|
||||||
|
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
||||||
|
if self.in_streamlit: st.markdown("Saving videos ... done!")
|
||||||
|
if self.in_streamlit: st.markdown("Finished!")
|
||||||
|
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
||||||
|
if self.in_streamlit: st.video(video_file.read())
|
||||||
|
|||||||
@@ -1,86 +1,10 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
st.set_page_config(layout="wide")
|
st.set_page_config(layout="wide")
|
||||||
from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames
|
from diffsynth import SDVideoPipelineRunner
|
||||||
import torch, os, json
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class Runner:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(self, model_list, textual_inversion_folder, device, lora_alphas, controlnet_units):
|
|
||||||
# Load models
|
|
||||||
model_manager = ModelManager(torch_dtype=torch.float16, device=device)
|
|
||||||
model_manager.load_textual_inversions(textual_inversion_folder)
|
|
||||||
model_manager.load_models(model_list, lora_alphas=lora_alphas)
|
|
||||||
pipe = SDVideoPipeline.from_model_manager(
|
|
||||||
model_manager,
|
|
||||||
[
|
|
||||||
ControlNetConfigUnit(
|
|
||||||
processor_id=unit["processor_id"],
|
|
||||||
model_path=unit["model_path"],
|
|
||||||
scale=unit["scale"]
|
|
||||||
) for unit in controlnet_units
|
|
||||||
]
|
|
||||||
)
|
|
||||||
return model_manager, pipe
|
|
||||||
|
|
||||||
|
|
||||||
def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
progress_bar_st = st.progress(0.0)
|
|
||||||
output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
|
|
||||||
progress_bar_st.progress(1.0)
|
|
||||||
model_manager.to("cpu")
|
|
||||||
return output_video
|
|
||||||
|
|
||||||
|
|
||||||
def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
|
|
||||||
video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
|
|
||||||
frames = [video[i] for i in range(start_frame_id, end_frame_id)]
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
|
|
||||||
pipeline_inputs["input_frames"] = self.load_video(**data["input_frames"])
|
|
||||||
pipeline_inputs["num_frames"] = len(pipeline_inputs["input_frames"])
|
|
||||||
pipeline_inputs["width"], pipeline_inputs["height"] = pipeline_inputs["input_frames"][0].size
|
|
||||||
if len(data["controlnet_frames"]) > 0:
|
|
||||||
pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
|
|
||||||
return pipeline_inputs
|
|
||||||
|
|
||||||
|
|
||||||
def save_output(self, video, output_folder, fps, config):
|
|
||||||
os.makedirs(output_folder, exist_ok=True)
|
|
||||||
save_frames(video, os.path.join(output_folder, "frames"))
|
|
||||||
save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
|
|
||||||
config["pipeline"]["pipeline_inputs"]["input_frames"] = []
|
|
||||||
config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
|
|
||||||
with open(os.path.join(output_folder, "config.json"), 'w') as file:
|
|
||||||
json.dump(config, file, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
def run(self, config):
|
|
||||||
st.markdown("Loading videos ...")
|
|
||||||
config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
|
|
||||||
st.markdown("Loading videos ... done!")
|
|
||||||
st.markdown("Loading models ...")
|
|
||||||
model_manager, pipe = self.load_pipeline(**config["models"])
|
|
||||||
st.markdown("Loading models ... done!")
|
|
||||||
st.markdown("Synthesizing videos ...")
|
|
||||||
output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
|
|
||||||
st.markdown("Synthesizing videos ... done!")
|
|
||||||
st.markdown("Saving videos ...")
|
|
||||||
self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
|
|
||||||
st.markdown("Saving videos ... done!")
|
|
||||||
st.markdown("Finished!")
|
|
||||||
video_file = open(os.path.join(os.path.join(config["data"]["output_folder"], "video.mp4")), 'rb')
|
|
||||||
st.video(video_file.read())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(folder):
|
def load_model_list(folder):
|
||||||
file_list = os.listdir(folder)
|
file_list = os.listdir(folder)
|
||||||
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
file_list = [i for i in file_list if i.endswith(".safetensors") or i.endswith(".pth") or i.endswith(".ckpt")]
|
||||||
@@ -270,4 +194,4 @@ with st.container(border=True):
|
|||||||
|
|
||||||
run_button = st.button("☢️Run☢️", type="primary")
|
run_button = st.button("☢️Run☢️", type="primary")
|
||||||
if run_button:
|
if run_button:
|
||||||
Runner().run(config)
|
SDVideoPipelineRunner(in_streamlit=True).run(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user