diff --git a/README.md b/README.md index afd6a8a..59df0b9 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ Create Python environment: conda env create -f environment.yml ``` +We find that `conda` sometimes fails to install `cupy` correctly. If that happens, please install it manually; see [this document](https://docs.cupy.dev/en/stable/install.html) for more details. + Enter the Python environment: ``` diff --git a/diffsynth/controlnets/controlnet_unit.py b/diffsynth/controlnets/controlnet_unit.py index 3a0ad97..9754bae 100644 --- a/diffsynth/controlnets/controlnet_unit.py +++ b/diffsynth/controlnets/controlnet_unit.py @@ -23,13 +23,11 @@ class MultiControlNetManager: self.models = [unit.model for unit in controlnet_units] self.scales = [unit.scale for unit in controlnet_units] - def process_image(self, image, return_image=False): - processed_image = [ - processor(image) - for processor in self.processors - ] - if return_image: - return processed_image + def process_image(self, image, processor_id=None): + if processor_id is None: + processed_image = [processor(image) for processor in self.processors] + else: + processed_image = [self.processors[processor_id](image)] processed_image = torch.concat([ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0) for image_ in processed_image diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index d6ea121..a378842 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -16,15 +16,15 @@ class Annotator: if processor_id == "canny": self.processor = CannyDetector() elif processor_id == "depth": - self.processor = MidasDetector.from_pretrained(model_path) + self.processor = MidasDetector.from_pretrained(model_path).to("cuda") elif processor_id == "softedge": - self.processor = HEDdetector.from_pretrained(model_path) + self.processor = HEDdetector.from_pretrained(model_path).to("cuda") elif processor_id == "lineart": - self.processor = 
LineartDetector.from_pretrained(model_path) + self.processor = LineartDetector.from_pretrained(model_path).to("cuda") elif processor_id == "lineart_anime": - self.processor = LineartAnimeDetector.from_pretrained(model_path) + self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") elif processor_id == "openpose": - self.processor = OpenposeDetector.from_pretrained(model_path) + self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") elif processor_id == "tile": self.processor = None else: diff --git a/diffsynth/extensions/FastBlend/__init__.py b/diffsynth/extensions/FastBlend/__init__.py index 5ff410a..2bf812c 100644 --- a/diffsynth/extensions/FastBlend/__init__.py +++ b/diffsynth/extensions/FastBlend/__init__.py @@ -7,7 +7,7 @@ import cupy as cp class FastBlendSmoother: def __init__(self): self.batch_size = 8 - self.window_size = 32 + self.window_size = 64 self.ebsynth_config = { "minimum_patch_size": 5, "threads_per_block": 8, diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 899d0b5..9daab9a 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -189,6 +189,7 @@ class ModelManager: model.to(device) else: self.model[component].to(device) + torch.cuda.empty_cache() def get_model_with_model_path(self, model_path): for component in self.model_path: diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index 6b78261..fe876e6 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -51,6 +51,9 @@ def lets_dance_with_long_video( hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias)) hidden_states_output[i] = (hidden_states, num + 1) + if batch_id_ == num_frames: + break + # output hidden_states = torch.stack([h for h, _ in hidden_states_output]) return hidden_states @@ -195,10 +198,21 @@ class 
SDVideoPipeline(torch.nn.Module): # Prepare ControlNets if controlnet_frames is not None: - controlnet_frames = torch.stack([ - self.controlnet.process_image(controlnet_frame).to(self.torch_dtype) - for controlnet_frame in progress_bar_cmd(controlnet_frames) - ], dim=1) + if isinstance(controlnet_frames[0], list): + controlnet_frames_ = [] + for processor_id in range(len(controlnet_frames)): + controlnet_frames_.append( + torch.stack([ + self.controlnet.process_image(controlnet_frame, processor_id=processor_id).to(self.torch_dtype) + for controlnet_frame in progress_bar_cmd(controlnet_frames[processor_id]) + ], dim=1) + ) + controlnet_frames = torch.concat(controlnet_frames_, dim=0) + else: + controlnet_frames = torch.stack([ + self.controlnet.process_image(controlnet_frame).to(self.torch_dtype) + for controlnet_frame in progress_bar_cmd(controlnet_frames) + ], dim=1) # Denoise for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): diff --git a/environment.yml b/environment.yml index bb3b5af..eb0edf4 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - pip=23.0.1 - cudatoolkit - pytorch + - cupy - pip: - transformers - controlnet-aux==0.0.7 diff --git a/pages/1_Image_Creator.py b/pages/1_Image_Creator.py index 8c735fa..9314f53 100644 --- a/pages/1_Image_Creator.py +++ b/pages/1_Image_Creator.py @@ -2,7 +2,6 @@ import torch, os, io import numpy as np from PIL import Image import streamlit as st -st.set_page_config(layout="wide") from streamlit_drawable_canvas import st_canvas from diffsynth.models import ModelManager from diffsynth.pipelines import SDImagePipeline, SDXLImagePipeline diff --git a/pages/2_Video_Creator.py b/pages/2_Video_Creator.py index 08e2f5f..ecb3bbc 100644 --- a/pages/2_Video_Creator.py +++ b/pages/2_Video_Creator.py @@ -1,4 +1,263 @@ import streamlit as st -st.set_page_config(layout="wide") +from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, 
import json
import os

import numpy as np
import streamlit as st
import torch

from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames


# File extensions recognised as model checkpoints when listing a model folder.
MODEL_FILE_EXTENSIONS = (".safetensors", ".pth", ".ckpt")

# Choices offered by every "Height" / "Width" slider on this page.
RESOLUTION_OPTIONS = [256, 512, 768, 1024, 1536, 2048]


class Runner:
    """Run one video-synthesis job described by the page's ``config`` dict."""

    def __init__(self):
        pass

    def load_pipeline(self, model_list, textual_inversion_folder, device, controlnet_units):
        """Load the selected models and build the video pipeline.

        Args:
            model_list: Paths of the checkpoints passed to ``ModelManager.load_models``.
            textual_inversion_folder: Folder passed to
                ``ModelManager.load_textual_inversions``.
            device: Device the ``ModelManager`` is created on (e.g. ``"cuda"``).
            controlnet_units: Dicts with the keys ``"processor_id"``,
                ``"model_path"`` and ``"scale"``; each becomes one
                ``ControlNetConfigUnit``.

        Returns:
            A ``(model_manager, pipe)`` tuple.
        """
        model_manager = ModelManager(torch_dtype=torch.float16, device=device)
        model_manager.load_textual_inversions(textual_inversion_folder)
        model_manager.load_models(model_list)
        controlnet_config_units = [
            ControlNetConfigUnit(
                processor_id=unit["processor_id"],
                model_path=unit["model_path"],
                scale=unit["scale"],
            )
            for unit in controlnet_units
        ]
        pipe = SDVideoPipeline.from_model_manager(model_manager, controlnet_config_units)
        return model_manager, pipe

    def synthesize_video(self, model_manager, pipe, seed, **pipeline_inputs):
        """Seed torch, run ``pipe`` with a Streamlit progress bar and return its output.

        The model manager is moved to the CPU afterwards, even when the
        pipeline raises, so a failed run does not keep the models on the GPU.
        """
        torch.manual_seed(seed)
        progress_bar_st = st.progress(0.0)
        try:
            output_video = pipe(**pipeline_inputs, progress_bar_st=progress_bar_st)
            progress_bar_st.progress(1.0)
        finally:
            model_manager.to("cpu")
        return output_video

    def load_video(self, video_file, image_folder, height, width, start_frame_id, end_frame_id):
        """Return the frames ``start_frame_id`` (inclusive) to ``end_frame_id`` (exclusive).

        The frames are read by index from a ``VideoData`` built from the given
        file or folder and target size.
        """
        video = VideoData(video_file=video_file, image_folder=image_folder, height=height, width=width)
        return [video[frame_id] for frame_id in range(start_frame_id, end_frame_id)]

    def add_data_to_pipeline_inputs(self, data, pipeline_inputs):
        """Load the frames described by ``data`` into ``pipeline_inputs``.

        Sets ``"input_frames"``, ``"num_frames"``, ``"width"`` and ``"height"``
        (the latter two taken from the first frame's ``size``). When ``data``
        lists ControlNet sources, ``"controlnet_frames"`` is set to one list of
        frames per source. ``pipeline_inputs`` is modified in place and
        returned.
        """
        input_frames = self.load_video(**data["input_frames"])
        pipeline_inputs["input_frames"] = input_frames
        pipeline_inputs["num_frames"] = len(input_frames)
        pipeline_inputs["width"], pipeline_inputs["height"] = input_frames[0].size
        if data["controlnet_frames"]:
            pipeline_inputs["controlnet_frames"] = [self.load_video(**unit) for unit in data["controlnet_frames"]]
        return pipeline_inputs

    def save_output(self, video, output_folder, fps, config):
        """Write the frames, ``video.mp4`` and ``config.json`` into ``output_folder``.

        Note that ``config`` is modified in place: the loaded frames inside
        ``config["pipeline"]["pipeline_inputs"]`` are replaced by empty lists
        before the config is dumped (presumably because frame objects are not
        JSON-serializable).
        """
        os.makedirs(output_folder, exist_ok=True)
        save_frames(video, os.path.join(output_folder, "frames"))
        save_video(video, os.path.join(output_folder, "video.mp4"), fps=fps)
        config["pipeline"]["pipeline_inputs"]["input_frames"] = []
        config["pipeline"]["pipeline_inputs"]["controlnet_frames"] = []
        with open(os.path.join(output_folder, "config.json"), 'w') as file:
            json.dump(config, file, indent=4)

    def _find_data_problem(self, data):
        """Return a user-facing message if ``data`` cannot be loaded, else ``None``."""
        sources = [("input video", data["input_frames"])]
        sources.extend(("ControlNet input video", unit) for unit in data["controlnet_frames"])
        for label, source in sources:
            # A source is unusable when it was never filled in or names no file/folder.
            if source is None or not (source.get("video_file") or source.get("image_folder")):
                return f"Please provide a file path for the {label}."
            if source["start_frame_id"] >= source["end_frame_id"]:
                return f"The end frame id of the {label} must be greater than its start frame id."
        return None

    def run(self, config):
        """Execute the whole job: load videos and models, synthesize, save, display.

        Progress is reported with ``st.markdown``. If a required video source
        is missing or its frame range is empty, an error is shown with
        ``st.error`` and nothing else is done.
        """
        problem = self._find_data_problem(config["data"])
        if problem is not None:
            st.error(problem)
            return
        st.markdown("Loading videos ...")
        config["pipeline"]["pipeline_inputs"] = self.add_data_to_pipeline_inputs(config["data"], config["pipeline"]["pipeline_inputs"])
        st.markdown("Loading videos ... done!")
        st.markdown("Loading models ...")
        model_manager, pipe = self.load_pipeline(**config["models"])
        st.markdown("Loading models ... done!")
        st.markdown("Synthesizing videos ...")
        output_video = self.synthesize_video(model_manager, pipe, config["pipeline"]["seed"], **config["pipeline"]["pipeline_inputs"])
        st.markdown("Synthesizing videos ... done!")
        st.markdown("Saving videos ...")
        self.save_output(output_video, config["data"]["output_folder"], config["data"]["fps"], config)
        st.markdown("Saving videos ... done!")
        st.markdown("Finished!")
        # Read the rendered file inside a context manager so the handle is closed.
        with open(os.path.join(config["data"]["output_folder"], "video.mp4"), 'rb') as video_file:
            st.video(video_file.read())


def load_model_list(folder):
    """Return the sorted checkpoint file names found directly inside ``folder``.

    Only names ending in ``.safetensors``, ``.pth`` or ``.ckpt`` are kept. A
    missing folder yields an empty list, so the page still renders when an
    optional model directory has not been created.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(name for name in os.listdir(folder) if name.endswith(MODEL_FILE_EXTENSIONS))


def match_processor_id(model_name, supported_processor_id_list):
    """Guess the selectbox index of the processor matching ``model_name``.

    Returns the 1-based position in ``supported_processor_id_list`` of the
    first processor id contained in ``model_name``, or 0 when none matches.
    Longer ids are tried first so that e.g. ``"lineart_anime"`` wins over
    ``"lineart"``.
    """
    longest_first = sorted(supported_processor_id_list, key=lambda name: (-len(name), name))
    for processor_id in longest_first:
        if processor_id in model_name:
            return supported_processor_id_list.index(processor_id) + 1
    return 0


config = {
    "models": {
        "model_list": [],
        "textual_inversion_folder": "models/textual_inversion",
        "device": "cuda",
        "controlnet_units": []
    },
    "data": {
        "input_frames": None,
        "controlnet_frames": [],
        "output_folder": "output",
        "fps": 60
    },
    "pipeline": {
        "seed": 0,
        "pipeline_inputs": {}
    }
}


with st.expander("Model", expanded=True):
    stable_diffusion_ckpt = st.selectbox("Stable Diffusion", ["None"] + load_model_list("models/stable_diffusion"))
    if stable_diffusion_ckpt != "None":
        config["models"]["model_list"].append(os.path.join("models/stable_diffusion", stable_diffusion_ckpt))
    animatediff_ckpt = st.selectbox("AnimateDiff", ["None"] + load_model_list("models/AnimateDiff"))
    if animatediff_ckpt != "None":
        config["models"]["model_list"].append(os.path.join("models/AnimateDiff", animatediff_ckpt))


with st.expander("Data", expanded=True):
    with st.container(border=True):
        input_video = st.text_input("Input Video File Path (e.g., data/your_video.mp4)", value="")
        column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
        with column_height:
            height = st.select_slider("Height", options=RESOLUTION_OPTIONS, value=1024)
        with column_width:
            width = st.select_slider("Width", options=RESOLUTION_OPTIONS, value=1024)
        with column_start_frame_index:
            start_frame_id = st.number_input("Start Frame id", value=0)
        with column_end_frame_index:
            end_frame_id = st.number_input("End Frame id", value=16)
        if input_video != "":
            config["data"]["input_frames"] = {
                "video_file": input_video,
                "image_folder": None,
                "height": height,
                "width": width,
                "start_frame_id": start_frame_id,
                "end_frame_id": end_frame_id
            }
    with st.container(border=True):
        output_video = st.text_input("Output Video File Path (e.g., data/a_folder_to_save_something)", value="output")
        fps = st.number_input("FPS", value=60)
        config["data"]["output_folder"] = output_video
        config["data"]["fps"] = fps


with st.expander("ControlNet Units", expanded=True):
    supported_processor_id_list = ["canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"]
    controlnet_units = st.tabs(["ControlNet Unit 0", "ControlNet Unit 1", "ControlNet Unit 2"])
    for controlnet_id, controlnet_tab in enumerate(controlnet_units):
        with controlnet_tab:
            controlnet_ckpt = st.selectbox("ControlNet", ["None"] + load_model_list("models/ControlNet"),
                                           key=f"controlnet_ckpt_{controlnet_id}")
            # Every other widget of the unit is greyed out until a checkpoint is chosen.
            unit_disabled = controlnet_ckpt == "None"
            processor_id = st.selectbox("Processor", ["None"] + supported_processor_id_list,
                                        index=match_processor_id(controlnet_ckpt, supported_processor_id_list),
                                        disabled=unit_disabled, key=f"processor_id_{controlnet_id}")
            controlnet_scale = st.slider("Scale", min_value=0.0, max_value=1.0, step=0.01, value=0.5,
                                         disabled=unit_disabled, key=f"controlnet_scale_{controlnet_id}")
            use_input_video_as_controlnet_input = st.checkbox("Use input video as ControlNet input", value=True,
                                                              disabled=unit_disabled,
                                                              key=f"use_input_video_as_controlnet_input_{controlnet_id}")
            # By default the unit is conditioned on the same frames as the main input.
            controlnet_frames_source = config["data"]["input_frames"]
            if not use_input_video_as_controlnet_input:
                controlnet_input_video = st.text_input("ControlNet Input Video File Path", value="",
                                                       disabled=unit_disabled, key=f"controlnet_input_video_{controlnet_id}")
                column_height, column_width, column_start_frame_index, column_end_frame_index = st.columns([2, 2, 1, 1])
                with column_height:
                    controlnet_height = st.select_slider("Height", options=RESOLUTION_OPTIONS, value=1024,
                                                         disabled=unit_disabled, key=f"controlnet_height_{controlnet_id}")
                with column_width:
                    controlnet_width = st.select_slider("Width", options=RESOLUTION_OPTIONS, value=1024,
                                                        disabled=unit_disabled, key=f"controlnet_width_{controlnet_id}")
                with column_start_frame_index:
                    controlnet_start_frame_id = st.number_input("Start Frame id", value=0,
                                                                disabled=unit_disabled, key=f"controlnet_start_frame_id_{controlnet_id}")
                with column_end_frame_index:
                    controlnet_end_frame_id = st.number_input("End Frame id", value=16,
                                                              disabled=unit_disabled, key=f"controlnet_end_frame_id_{controlnet_id}")
                # The unit reads its own video, not the main input video.
                controlnet_frames_source = {
                    "video_file": controlnet_input_video,
                    "image_folder": None,
                    "height": controlnet_height,
                    "width": controlnet_width,
                    "start_frame_id": controlnet_start_frame_id,
                    "end_frame_id": controlnet_end_frame_id
                }
            if controlnet_ckpt != "None":
                controlnet_model_path = os.path.join("models/ControlNet", controlnet_ckpt)
                config["models"]["model_list"].append(controlnet_model_path)
                config["models"]["controlnet_units"].append({
                    "processor_id": processor_id,
                    "model_path": controlnet_model_path,
                    "scale": controlnet_scale,
                })
                config["data"]["controlnet_frames"].append(controlnet_frames_source)


with st.container(border=True):
    with st.expander("Seed", expanded=True):
        use_fixed_seed = st.checkbox("Use fixed seed", value=False)
        if use_fixed_seed:
            seed = st.number_input("Random seed", min_value=0, max_value=10**9, step=1, value=0)
        else:
            seed = np.random.randint(0, 10**9)
    with st.expander("Textual Guidance", expanded=True):
        prompt = st.text_area("Positive prompt")
        negative_prompt = st.text_area("Negative prompt")
        column_cfg_scale, column_clip_skip = st.columns(2)
        with column_cfg_scale:
            cfg_scale = st.slider("Classifier-free guidance scale", min_value=1.0, max_value=10.0, value=7.0)
        with column_clip_skip:
            clip_skip = st.slider("Clip Skip", min_value=1, max_value=4, value=1)
    with st.expander("Denoising", expanded=True):
        column_num_inference_steps, column_denoising_strength = st.columns(2)
        with column_num_inference_steps:
            num_inference_steps = st.slider("Inference steps", min_value=1, max_value=100, value=10)
        with column_denoising_strength:
            denoising_strength = st.slider("Denoising strength", min_value=0.0, max_value=1.0, value=1.0)
    with st.expander("Efficiency", expanded=False):
        animatediff_batch_size = st.slider("Animatediff batch size (sliding window size)", min_value=1, max_value=32, value=16, step=1)
        animatediff_stride = st.slider("Animatediff stride",
                                       min_value=1,
                                       max_value=max(2, animatediff_batch_size),
                                       value=max(1, animatediff_batch_size // 2),
                                       step=1)
        unet_batch_size = st.slider("UNet batch size", min_value=1, max_value=32, value=1, step=1)
        controlnet_batch_size = st.slider("ControlNet batch size", min_value=1, max_value=32, value=1, step=1)
        cross_frame_attention = st.checkbox("Enable Cross-Frame Attention", value=False)

config["pipeline"]["seed"] = seed
config["pipeline"]["pipeline_inputs"] = {
    "prompt": prompt,
    "negative_prompt": negative_prompt,
    "cfg_scale": cfg_scale,
    "clip_skip": clip_skip,
    "denoising_strength": denoising_strength,
    "num_inference_steps": num_inference_steps,
    "animatediff_batch_size": animatediff_batch_size,
    "animatediff_stride": animatediff_stride,
    "unet_batch_size": unet_batch_size,
    "controlnet_batch_size": controlnet_batch_size,
    "cross_frame_attention": cross_frame_attention,
}

run_button = st.button("☢️Run☢️", type="primary")
if run_button:
    Runner().run(config)