From d24ddaacaa5f9c67e5323ead04fec77a92af48ac Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Sat, 30 Dec 2023 21:01:24 +0800 Subject: [PATCH] v1.2 --- diffsynth/extensions/FastBlend/__init__.py | 63 +++ diffsynth/extensions/FastBlend/api.py | 397 ++++++++++++++++ .../extensions/FastBlend/cupy_kernels.py | 119 +++++ diffsynth/extensions/FastBlend/data.py | 146 ++++++ diffsynth/extensions/FastBlend/patch_match.py | 298 ++++++++++++ .../extensions/FastBlend/runners/__init__.py | 4 + .../extensions/FastBlend/runners/accurate.py | 35 ++ .../extensions/FastBlend/runners/balanced.py | 46 ++ .../extensions/FastBlend/runners/fast.py | 141 ++++++ .../FastBlend/runners/interpolation.py | 121 +++++ diffsynth/extensions/RIFE/__init__.py | 241 ++++++++++ diffsynth/models/__init__.py | 18 +- diffsynth/models/sd_unet.py | 37 +- diffsynth/models/svd_unet.py | 436 ++++++++++++++++++ diffsynth/pipelines/dancer.py | 9 +- diffsynth/pipelines/stable_diffusion_video.py | 39 +- examples/sd_text_to_video.py | 47 ++ examples/sd_toon_shading.py | 31 +- examples/sd_video_rerender.py | 58 +++ 19 files changed, 2252 insertions(+), 34 deletions(-) create mode 100644 diffsynth/extensions/FastBlend/__init__.py create mode 100644 diffsynth/extensions/FastBlend/api.py create mode 100644 diffsynth/extensions/FastBlend/cupy_kernels.py create mode 100644 diffsynth/extensions/FastBlend/data.py create mode 100644 diffsynth/extensions/FastBlend/patch_match.py create mode 100644 diffsynth/extensions/FastBlend/runners/__init__.py create mode 100644 diffsynth/extensions/FastBlend/runners/accurate.py create mode 100644 diffsynth/extensions/FastBlend/runners/balanced.py create mode 100644 diffsynth/extensions/FastBlend/runners/fast.py create mode 100644 diffsynth/extensions/FastBlend/runners/interpolation.py create mode 100644 diffsynth/extensions/RIFE/__init__.py create mode 100644 diffsynth/models/svd_unet.py create mode 100644 examples/sd_text_to_video.py create mode 100644 examples/sd_video_rerender.py 
diff --git a/diffsynth/extensions/FastBlend/__init__.py b/diffsynth/extensions/FastBlend/__init__.py new file mode 100644 index 0000000..5ff410a --- /dev/null +++ b/diffsynth/extensions/FastBlend/__init__.py @@ -0,0 +1,63 @@ +from .runners.fast import TableManager, PyramidPatchMatcher +from PIL import Image +import numpy as np +import cupy as cp + + +class FastBlendSmoother: + def __init__(self): + self.batch_size = 8 + self.window_size = 32 + self.ebsynth_config = { + "minimum_patch_size": 5, + "threads_per_block": 8, + "num_iter": 5, + "gpu_id": 0, + "guide_weight": 10.0, + "initialize": "identity", + "tracking_window_size": 0, + } + + @staticmethod + def from_model_manager(model_manager): + # TODO: fetch GPU ID from model_manager + return FastBlendSmoother() + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config): + frames_guide = [np.array(frame) for frame in frames_guide] + frames_style = [np.array(frame) for frame in frames_style] + table_manager = TableManager() + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **ebsynth_config + ) + # left part + table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4") + table_l = table_manager.remapping_table_to_blending_table(table_l) + table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4") + # right part + table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4") + table_r = table_manager.remapping_table_to_blending_table(table_r) + table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1] + # merge + frames = [] + for (frame_l, weight_l), frame_m, (frame_r, weight_r) in 
zip(table_l, frames_style, table_r): + weight_m = -1 + weight = weight_l + weight_m + weight_r + frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight) + frames.append(frame) + frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames] + return frames + + def __call__(self, rendered_frames, original_frames=None, **kwargs): + frames = self.run( + original_frames, rendered_frames, + self.batch_size, self.window_size, self.ebsynth_config + ) + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + return frames \ No newline at end of file diff --git a/diffsynth/extensions/FastBlend/api.py b/diffsynth/extensions/FastBlend/api.py new file mode 100644 index 0000000..2db2433 --- /dev/null +++ b/diffsynth/extensions/FastBlend/api.py @@ -0,0 +1,397 @@ +from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner +from .data import VideoData, get_video_fps, save_video, search_for_images +import os +import gradio as gr + + +def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder): + frames_guide = VideoData(video_guide, video_guide_folder) + frames_style = VideoData(video_style, video_style_folder) + message = "" + if len(frames_guide) < len(frames_style): + message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n" + frames_style.set_length(len(frames_guide)) + elif len(frames_guide) > len(frames_style): + message += f"The number of frames mismatches. 
Only the first {len(frames_style)} frames of guide video will be used.\n" + frames_guide.set_length(len(frames_style)) + height_guide, width_guide = frames_guide.shape() + height_style, width_style = frames_style.shape() + if height_guide != height_style or width_guide != width_style: + message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n" + frames_style.set_shape(height_guide, width_guide) + return frames_guide, frames_style, message + + +def smooth_video( + video_guide, + video_guide_folder, + video_style, + video_style_folder, + mode, + window_size, + batch_size, + tracking_window_size, + output_path, + fps, + minimum_patch_size, + num_iter, + guide_weight, + initialize, + progress = None, +): + # input + frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder) + if len(message) > 0: + print(message) + # output + if output_path == "": + if video_style is None: + output_path = os.path.join(video_style_folder, "output") + else: + output_path = os.path.join(os.path.split(video_style)[0], "output") + os.makedirs(output_path, exist_ok=True) + print("No valid output_path. 
Your video will be saved here:", output_path) + elif not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + print("Your video will be saved here:", output_path) + frames_path = os.path.join(output_path, "frames") + video_path = os.path.join(output_path, "video.mp4") + os.makedirs(frames_path, exist_ok=True) + # process + if mode == "Fast" or mode == "Balanced": + tracking_window_size = 0 + ebsynth_config = { + "minimum_patch_size": minimum_patch_size, + "threads_per_block": 8, + "num_iter": num_iter, + "gpu_id": 0, + "guide_weight": guide_weight, + "initialize": initialize, + "tracking_window_size": tracking_window_size, + } + if mode == "Fast": + FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + elif mode == "Balanced": + BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + elif mode == "Accurate": + AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path) + # output + try: + fps = int(fps) + except: + fps = get_video_fps(video_style) if video_style is not None else 30 + print("Fps:", fps) + print("Saving video...") + video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps) + print("Success!") + print("Your frames are here:", frames_path) + print("Your video is here:", video_path) + return output_path, fps, video_path + + +class KeyFrameMatcher: + def __init__(self): + pass + + def extract_number_from_filename(self, file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + if number != -1: + result.append(number) + result = 
tuple(result) + return result + + def extract_number_from_filenames(self, file_names): + numbers = [self.extract_number_from_filename(file_name) for file_name in file_names] + min_length = min(len(i) for i in numbers) + for i in range(min_length-1, -1, -1): + if len(set(number[i] for number in numbers))==len(file_names): + return [number[i] for number in numbers] + return list(range(len(file_names))) + + def match_using_filename(self, file_names_a, file_names_b): + file_names_b_set = set(file_names_b) + matched_file_name = [] + for file_name in file_names_a: + if file_name not in file_names_b_set: + matched_file_name.append(None) + else: + matched_file_name.append(file_name) + return matched_file_name + + def match_using_numbers(self, file_names_a, file_names_b): + numbers_a = self.extract_number_from_filenames(file_names_a) + numbers_b = self.extract_number_from_filenames(file_names_b) + numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)} + matched_file_name = [] + for number in numbers_a: + if number in numbers_b_dict: + matched_file_name.append(numbers_b_dict[number]) + else: + matched_file_name.append(None) + return matched_file_name + + def match_filenames(self, file_names_a, file_names_b): + matched_file_name = self.match_using_filename(file_names_a, file_names_b) + if sum([i is not None for i in matched_file_name]) > 0: + return matched_file_name + matched_file_name = self.match_using_numbers(file_names_a, file_names_b) + return matched_file_name + + +def detect_frames(frames_path, keyframes_path): + if not os.path.exists(frames_path) and not os.path.exists(keyframes_path): + return "Please input the directory of guide video and rendered frames" + elif not os.path.exists(frames_path): + return "Please input the directory of guide video" + elif not os.path.exists(keyframes_path): + return "Please input the directory of rendered frames" + frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)] + keyframes 
= [os.path.split(i)[-1] for i in search_for_images(keyframes_path)] + if len(frames)==0: + return f"No images detected in {frames_path}" + if len(keyframes)==0: + return f"No images detected in {keyframes_path}" + matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes) + max_filename_length = max([len(i) for i in frames]) + if sum([i is not None for i in matched_keyframes])==0: + message = "" + for frame, matched_keyframe in zip(frames, matched_keyframes): + message += frame + " " * (max_filename_length - len(frame) + 1) + message += "--> No matched keyframes\n" + else: + message = "" + for frame, matched_keyframe in zip(frames, matched_keyframes): + message += frame + " " * (max_filename_length - len(frame) + 1) + if matched_keyframe is None: + message += "--> [to be rendered]\n" + else: + message += f"--> {matched_keyframe}\n" + return message + + +def check_input_for_interpolating(frames_path, keyframes_path): + # search for images + frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)] + keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)] + # match frames + matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes) + file_list = [file_name for file_name in matched_keyframes if file_name is not None] + index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None] + frames_guide = VideoData(None, frames_path) + frames_style = VideoData(None, keyframes_path, file_list=file_list) + # match shape + message = "" + height_guide, width_guide = frames_guide.shape() + height_style, width_style = frames_style.shape() + if height_guide != height_style or width_guide != width_style: + message += f"The shape of frames mismatches. 
The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n" + frames_style.set_shape(height_guide, width_guide) + return frames_guide, frames_style, index_style, message + + +def interpolate_video( + frames_path, + keyframes_path, + output_path, + fps, + batch_size, + tracking_window_size, + minimum_patch_size, + num_iter, + guide_weight, + initialize, + progress = None, +): + # input + frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path) + if len(message) > 0: + print(message) + # output + if output_path == "": + output_path = os.path.join(keyframes_path, "output") + os.makedirs(output_path, exist_ok=True) + print("No valid output_path. Your video will be saved here:", output_path) + elif not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + print("Your video will be saved here:", output_path) + output_frames_path = os.path.join(output_path, "frames") + output_video_path = os.path.join(output_path, "video.mp4") + os.makedirs(output_frames_path, exist_ok=True) + # process + ebsynth_config = { + "minimum_patch_size": minimum_patch_size, + "threads_per_block": 8, + "num_iter": num_iter, + "gpu_id": 0, + "guide_weight": guide_weight, + "initialize": initialize, + "tracking_window_size": tracking_window_size + } + if len(index_style)==1: + InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path) + else: + InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path) + try: + fps = int(fps) + except: + fps = 30 + print("Fps:", fps) + print("Saving video...") + video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps) + print("Success!") + print("Your frames are here:", output_frames_path) + print("Your video is here:", 
video_path) + return output_path, fps, video_path + + +def on_ui_tabs(): + with gr.Blocks(analytics_enabled=False) as ui_component: + with gr.Tab("Blend"): + gr.Markdown(""" +# Blend + +Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution. + """) + with gr.Row(): + with gr.Column(): + with gr.Tab("Guide video"): + video_guide = gr.Video(label="Guide video") + with gr.Tab("Guide video (images format)"): + video_guide_folder = gr.Textbox(label="Guide video (images format)", value="") + with gr.Column(): + with gr.Tab("Style video"): + video_style = gr.Video(label="Style video") + with gr.Tab("Style video (images format)"): + video_style_folder = gr.Textbox(label="Style video (images format)", value="") + with gr.Column(): + output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video") + fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps") + video_output = gr.Video(label="Output video", interactive=False, show_share_button=True) + btn = gr.Button(value="Blend") + with gr.Row(): + with gr.Column(): + gr.Markdown("# Settings") + mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True) + window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True) + batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True) + tracking_window_size = gr.Slider(label="Tracking window size (only for accurate 
mode)", value=0, minimum=0, maximum=10, step=1, interactive=True) + gr.Markdown("## Advanced Settings") + minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True) + num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True) + guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True) + initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True) + with gr.Column(): + gr.Markdown(""" +# Reference + +* Output directory: the directory to save the video. +* Inference mode + +|Mode|Time|Memory|Quality|Frame by frame output|Description| +|-|-|-|-|-|-| +|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.| +|Balanced|■■|■|■■|Yes|Blend the frames naively.| +|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.| + +* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy. +* Batch size: a larger batch size makes the program faster but requires more VRAM. +* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough. +* Advanced settings + * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5) + * Number of iterations: the number of iterations of patch matching. (Default: 5) + * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10) + * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). 
(Default: identity) + """) + btn.click( + smooth_video, + inputs=[ + video_guide, + video_guide_folder, + video_style, + video_style_folder, + mode, + window_size, + batch_size, + tracking_window_size, + output_path, + fps, + minimum_patch_size, + num_iter, + guide_weight, + initialize + ], + outputs=[output_path, fps, video_output] + ) + with gr.Tab("Interpolate"): + gr.Markdown(""" +# Interpolate + +Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution. + """) + with gr.Row(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="") + with gr.Column(): + rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="") + with gr.Row(): + detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False) + video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames) + rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames) + with gr.Column(): + output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes") + fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps") + video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True) + btn_ = gr.Button(value="Interpolate") + with gr.Row(): + with gr.Column(): + gr.Markdown("# Settings") + batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True) + tracking_window_size_ = 
gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True) + gr.Markdown("## Advanced Settings") + minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True) + num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True) + guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True) + initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True) + with gr.Column(): + gr.Markdown(""" +# Reference + +* Output directory: the directory to save the video. +* Batch size: a larger batch size makes the program faster but requires more VRAM. +* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough. +* Advanced settings + * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)** + * Number of iterations: the number of iterations of patch matching. (Default: 5) + * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10) + * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). 
(Default: identity) + """) + btn_.click( + interpolate_video, + inputs=[ + video_guide_folder_, + rendered_keyframes_, + output_path_, + fps_, + batch_size_, + tracking_window_size_, + minimum_patch_size_, + num_iter_, + guide_weight_, + initialize_, + ], + outputs=[output_path_, fps_, video_output_] + ) + + return [(ui_component, "FastBlend", "FastBlend_ui")] diff --git a/diffsynth/extensions/FastBlend/cupy_kernels.py b/diffsynth/extensions/FastBlend/cupy_kernels.py new file mode 100644 index 0000000..70e2790 --- /dev/null +++ b/diffsynth/extensions/FastBlend/cupy_kernels.py @@ -0,0 +1,119 @@ +import cupy as cp + +remapping_kernel = cp.RawKernel(r''' +extern "C" __global__ +void remap( + const int height, + const int width, + const int channel, + const int patch_size, + const int pad_size, + const float* source_style, + const int* nnf, + float* target_style +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + if (x >= height or y >= width) return; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size); + const int min_px = x < r ? -x : -r; + const int max_px = x + r > height - 1 ? height - 1 - x : r; + const int min_py = y < r ? -y : -r; + const int max_py = y + r > width - 1 ? 
width - 1 - y : r; + int num = 0; + for (int px = min_px; px <= max_px; px++){ + for (int py = min_py; py <= max_py; py++){ + const int nid = (x + px) * width + y + py; + const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px; + const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py; + if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue; + const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size); + num++; + for (int c = 0; c < channel; c++){ + target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c]; + } + } + } + for (int c = 0; c < channel; c++){ + target_style[z + pid * channel + c] /= num; + } +} +''', 'remap') + + +patch_error_kernel = cp.RawKernel(r''' +extern "C" __global__ +void patch_error( + const int height, + const int width, + const int channel, + const int patch_size, + const int pad_size, + const float* source, + const int* nnf, + const float* target, + float* error +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + if (x >= height or y >= width) return; + const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0]; + const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1]; + float e = 0; + for (int px = -r; px <= r; px++){ + for (int py = -r; py <= r; py++){ + const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py; + const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py; + for (int c = 0; c < channel; c++){ + const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c]; + e += diff * diff; + } + } + } + error[blockIdx.z * height * width + x * width + y] = e; +} +''', 'patch_error') + + +pairwise_patch_error_kernel = cp.RawKernel(r''' +extern "C" __global__ +void 
pairwise_patch_error( + const int height, + const int width, + const int channel, + const int patch_size, + const int pad_size, + const float* source_a, + const int* nnf_a, + const float* source_b, + const int* nnf_b, + float* error +) { + const int r = (patch_size - 1) / 2; + const int x = blockDim.x * blockIdx.x + threadIdx.x; + const int y = blockDim.y * blockIdx.y + threadIdx.y; + const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel; + if (x >= height or y >= width) return; + const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2; + const int x_a = nnf_a[z_nnf + 0]; + const int y_a = nnf_a[z_nnf + 1]; + const int x_b = nnf_b[z_nnf + 0]; + const int y_b = nnf_b[z_nnf + 1]; + float e = 0; + for (int px = -r; px <= r; px++){ + for (int py = -r; py <= r; py++){ + const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py; + const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py; + for (int c = 0; c < channel; c++){ + const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c]; + e += diff * diff; + } + } + } + error[blockIdx.z * height * width + x * width + y] = e; +} +''', 'pairwise_patch_error') diff --git a/diffsynth/extensions/FastBlend/data.py b/diffsynth/extensions/FastBlend/data.py new file mode 100644 index 0000000..dcaddd7 --- /dev/null +++ b/diffsynth/extensions/FastBlend/data.py @@ -0,0 +1,146 @@ +import imageio, os +import numpy as np +from PIL import Image + + +def read_video(file_name): + reader = imageio.get_reader(file_name) + video = [] + for frame in reader: + frame = np.array(frame) + video.append(frame) + reader.close() + return video + + +def get_video_fps(file_name): + reader = imageio.get_reader(file_name) + fps = reader.get_meta_data()["fps"] + reader.close() + return fps + + +def save_video(frames_path, video_path, num_frames, fps): + writer = imageio.get_writer(video_path, fps=fps, quality=9) + for i in 
range(num_frames): + frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i))) + writer.append_data(frame) + writer.close() + return video_path + + +class LowMemoryVideo: + def __init__(self, file_name): + self.reader = imageio.get_reader(file_name) + + def __len__(self): + return self.reader.count_frames() + + def __getitem__(self, item): + return np.array(self.reader.get_data(item)) + + def __del__(self): + self.reader.close() + + +def split_file_name(file_name): + result = [] + number = -1 + for i in file_name: + if ord(i)>=ord("0") and ord(i)<=ord("9"): + if number == -1: + number = 0 + number = number*10 + ord(i) - ord("0") + else: + if number != -1: + result.append(number) + number = -1 + result.append(i) + if number != -1: + result.append(number) + result = tuple(result) + return result + + +def search_for_images(folder): + file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")] + file_list = [(split_file_name(file_name), file_name) for file_name in file_list] + file_list = [i[1] for i in sorted(file_list)] + file_list = [os.path.join(folder, i) for i in file_list] + return file_list + + +def read_images(folder): + file_list = search_for_images(folder) + frames = [np.array(Image.open(i)) for i in file_list] + return frames + + +class LowMemoryImageFolder: + def __init__(self, folder, file_list=None): + if file_list is None: + self.file_list = search_for_images(folder) + else: + self.file_list = [os.path.join(folder, file_name) for file_name in file_list] + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, item): + return np.array(Image.open(self.file_list[item])) + + def __del__(self): + pass + + +class VideoData: + def __init__(self, video_file, image_folder, **kwargs): + if video_file is not None: + self.data_type = "video" + self.data = LowMemoryVideo(video_file, **kwargs) + elif image_folder is not None: + self.data_type = "images" + self.data = LowMemoryImageFolder(image_folder, 
**kwargs) + else: + raise ValueError("Cannot open video or image folder") + self.length = None + self.height = None + self.width = None + + def raw_data(self): + frames = [] + for i in range(self.__len__()): + frames.append(self.__getitem__(i)) + return frames + + def set_length(self, length): + self.length = length + + def set_shape(self, height, width): + self.height = height + self.width = width + + def __len__(self): + if self.length is None: + return len(self.data) + else: + return self.length + + def shape(self): + if self.height is not None and self.width is not None: + return self.height, self.width + else: + height, width, _ = self.__getitem__(0).shape + return height, width + + def __getitem__(self, item): + frame = self.data.__getitem__(item) + height, width, _ = frame.shape + if self.height is not None and self.width is not None: + if self.height != height or self.width != width: + frame = Image.fromarray(frame).resize((self.width, self.height)) + frame = np.array(frame) + return frame + + def __del__(self): + pass diff --git a/diffsynth/extensions/FastBlend/patch_match.py b/diffsynth/extensions/FastBlend/patch_match.py new file mode 100644 index 0000000..aeb1f7f --- /dev/null +++ b/diffsynth/extensions/FastBlend/patch_match.py @@ -0,0 +1,298 @@ +from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel +import numpy as np +import cupy as cp +import cv2 + + +class PatchMatcher: + def __init__( + self, height, width, channel, minimum_patch_size, + threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, + random_search_steps=3, random_search_range=4, + use_mean_target_style=False, use_pairwise_patch_error=False, + tracking_window_size=0 + ): + self.height = height + self.width = width + self.channel = channel + self.minimum_patch_size = minimum_patch_size + self.threads_per_block = threads_per_block + self.num_iter = num_iter + self.gpu_id = gpu_id + self.guide_weight = guide_weight + self.random_search_steps = 
random_search_steps + self.random_search_range = random_search_range + self.use_mean_target_style = use_mean_target_style + self.use_pairwise_patch_error = use_pairwise_patch_error + self.tracking_window_size = tracking_window_size + + self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1] + self.pad_size = self.patch_size_list[0] // 2 + self.grid = ( + (height + threads_per_block - 1) // threads_per_block, + (width + threads_per_block - 1) // threads_per_block + ) + self.block = (threads_per_block, threads_per_block) + + def pad_image(self, image): + return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0))) + + def unpad_image(self, image): + return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :] + + def apply_nnf_to_image(self, nnf, source): + batch_size = source.shape[0] + target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32) + remapping_kernel( + self.grid + (batch_size,), + self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target) + ) + return target + + def get_patch_error(self, source, nnf, target): + batch_size = source.shape[0] + error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32) + patch_error_kernel( + self.grid + (batch_size,), + self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error) + ) + return error + + def get_pairwise_patch_error(self, source, nnf): + batch_size = source.shape[0]//2 + error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32) + source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy() + source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy() + pairwise_patch_error_kernel( + self.grid + (batch_size,), + self.block, + (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, 
error) + ) + error = error.repeat(2, axis=0) + return error + + def get_error(self, source_guide, target_guide, source_style, target_style, nnf): + error_guide = self.get_patch_error(source_guide, nnf, target_guide) + if self.use_mean_target_style: + target_style = self.apply_nnf_to_image(nnf, source_style) + target_style = target_style.mean(axis=0, keepdims=True) + target_style = target_style.repeat(source_guide.shape[0], axis=0) + if self.use_pairwise_patch_error: + error_style = self.get_pairwise_patch_error(source_style, nnf) + else: + error_style = self.get_patch_error(source_style, nnf, target_style) + error = error_guide * self.guide_weight + error_style + return error + + def clamp_bound(self, nnf): + nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1) + nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1) + return nnf + + def random_step(self, nnf, r): + batch_size = nnf.shape[0] + step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32) + upd_nnf = self.clamp_bound(nnf + step) + return upd_nnf + + def neighboor_step(self, nnf, d): + if d==0: + upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1) + upd_nnf[:, :, :, 0] += 1 + elif d==1: + upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2) + upd_nnf[:, :, :, 1] += 1 + elif d==2: + upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1) + upd_nnf[:, :, :, 0] -= 1 + elif d==3: + upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2) + upd_nnf[:, :, :, 1] -= 1 + upd_nnf = self.clamp_bound(upd_nnf) + return upd_nnf + + def shift_nnf(self, nnf, d): + if d>0: + d = min(nnf.shape[0], d) + upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0) + else: + d = max(-nnf.shape[0], d) + upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0) + return upd_nnf + + def track_step(self, nnf, d): + if self.use_pairwise_patch_error: + upd_nnf = cp.zeros_like(nnf) + upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d) + 
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d) + else: + upd_nnf = self.shift_nnf(nnf, d) + return upd_nnf + + def C(self, n, m): + # not used + c = 1 + for i in range(1, n+1): + c *= i + for i in range(1, m+1): + c //= i + for i in range(1, n-m+1): + c //= i + return c + + def bezier_step(self, nnf, r): + # not used + n = r * 2 - 1 + upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32) + for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))): + if d>0: + ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0) + elif d<0: + ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0) + upd_nnf += ctl_nnf * (self.C(n, i) / 2**n) + upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype) + return upd_nnf + + def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf): + upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf) + upd_idx = (upd_err < err) + nnf[upd_idx] = upd_nnf[upd_idx] + err[upd_idx] = upd_err[upd_idx] + return nnf, err + + def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err): + for d in cp.random.permutation(4): + upd_nnf = self.neighboor_step(nnf, d) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err): + for i in range(self.random_search_steps): + upd_nnf = self.random_step(nnf, self.random_search_range) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def track(self, source_guide, target_guide, source_style, target_style, nnf, err): + for d in range(1, self.tracking_window_size + 1): + upd_nnf = self.track_step(nnf, d) + nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf) + upd_nnf = self.track_step(nnf, -d) + nnf, err = self.update(source_guide, target_guide, 
source_style, target_style, nnf, err, upd_nnf) + return nnf, err + + def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err): + nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err) + nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err) + nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err) + return nnf, err + + def estimate_nnf(self, source_guide, target_guide, source_style, nnf): + with cp.cuda.Device(self.gpu_id): + source_guide = self.pad_image(source_guide) + target_guide = self.pad_image(target_guide) + source_style = self.pad_image(source_style) + for it in range(self.num_iter): + self.patch_size = self.patch_size_list[it] + target_style = self.apply_nnf_to_image(nnf, source_style) + err = self.get_error(source_guide, target_guide, source_style, target_style, nnf) + nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err) + target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style)) + return nnf, target_style + + +class PyramidPatchMatcher: + def __init__( + self, image_height, image_width, channel, minimum_patch_size, + threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, + use_mean_target_style=False, use_pairwise_patch_error=False, + tracking_window_size=0, + initialize="identity" + ): + maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2 + self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size)) + self.pyramid_heights = [] + self.pyramid_widths = [] + self.patch_matchers = [] + self.minimum_patch_size = minimum_patch_size + self.num_iter = num_iter + self.gpu_id = gpu_id + self.initialize = initialize + for level in range(self.pyramid_level): + height = image_height//(2**(self.pyramid_level - 1 - level)) + width = image_width//(2**(self.pyramid_level - 1 - level)) + self.pyramid_heights.append(height) + 
self.pyramid_widths.append(width) + self.patch_matchers.append(PatchMatcher( + height, width, channel, minimum_patch_size=minimum_patch_size, + threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight, + use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error, + tracking_window_size=tracking_window_size + )) + + def resample_image(self, images, level): + height, width = self.pyramid_heights[level], self.pyramid_widths[level] + images = images.get() + images_resample = [] + for image in images: + image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA) + images_resample.append(image_resample) + images_resample = cp.array(np.stack(images_resample), dtype=cp.float32) + return images_resample + + def initialize_nnf(self, batch_size): + if self.initialize == "random": + height, width = self.pyramid_heights[0], self.pyramid_widths[0] + nnf = cp.stack([ + cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32), + cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32) + ], axis=3) + elif self.initialize == "identity": + height, width = self.pyramid_heights[0], self.pyramid_widths[0] + nnf = cp.stack([ + cp.repeat(cp.arange(height), width).reshape(height, width), + cp.tile(cp.arange(width), height).reshape(height, width) + ], axis=2) + nnf = cp.stack([nnf] * batch_size) + else: + raise NotImplementedError() + return nnf + + def update_nnf(self, nnf, level): + # upscale + nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2 + nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1 + nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1 + # check if scale is 2 + height, width = self.pyramid_heights[level], self.pyramid_widths[level] + if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2: + nnf = nnf.get().astype(np.float32) + nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf] + nnf = 
cp.array(np.stack(nnf), dtype=cp.int32) + nnf = self.patch_matchers[level].clamp_bound(nnf) + return nnf + + def apply_nnf_to_image(self, nnf, image): + with cp.cuda.Device(self.gpu_id): + image = self.patch_matchers[-1].pad_image(image) + image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image) + return image + + def estimate_nnf(self, source_guide, target_guide, source_style): + with cp.cuda.Device(self.gpu_id): + if not isinstance(source_guide, cp.ndarray): + source_guide = cp.array(source_guide, dtype=cp.float32) + if not isinstance(target_guide, cp.ndarray): + target_guide = cp.array(target_guide, dtype=cp.float32) + if not isinstance(source_style, cp.ndarray): + source_style = cp.array(source_style, dtype=cp.float32) + for level in range(self.pyramid_level): + nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level) + source_guide_ = self.resample_image(source_guide, level) + target_guide_ = self.resample_image(target_guide, level) + source_style_ = self.resample_image(source_style, level) + nnf, target_style = self.patch_matchers[level].estimate_nnf( + source_guide_, target_guide_, source_style_, nnf + ) + return nnf.get(), target_style.get() diff --git a/diffsynth/extensions/FastBlend/runners/__init__.py b/diffsynth/extensions/FastBlend/runners/__init__.py new file mode 100644 index 0000000..0783827 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/__init__.py @@ -0,0 +1,4 @@ +from .accurate import AccurateModeRunner +from .fast import FastModeRunner +from .balanced import BalancedModeRunner +from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner diff --git a/diffsynth/extensions/FastBlend/runners/accurate.py b/diffsynth/extensions/FastBlend/runners/accurate.py new file mode 100644 index 0000000..2e4a47f --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/accurate.py @@ -0,0 +1,35 @@ +from ..patch_match import PyramidPatchMatcher +import os +import numpy as np +from 
PIL import Image +from tqdm import tqdm + + +class AccurateModeRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + use_mean_target_style=True, + **ebsynth_config + ) + # run + n = len(frames_style) + for target in tqdm(range(n), desc=desc): + l, r = max(target - window_size, 0), min(target + window_size + 1, n) + remapped_frames = [] + for i in range(l, r, batch_size): + j = min(i + batch_size, r) + source_guide = np.stack([frames_guide[source] for source in range(i, j)]) + target_guide = np.stack([frames_guide[target]] * (j - i)) + source_style = np.stack([frames_style[source] for source in range(i, j)]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + remapped_frames.append(target_style) + frame = np.concatenate(remapped_frames, axis=0).mean(axis=0) + frame = frame.clip(0, 255).astype("uint8") + if save_path is not None: + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) \ No newline at end of file diff --git a/diffsynth/extensions/FastBlend/runners/balanced.py b/diffsynth/extensions/FastBlend/runners/balanced.py new file mode 100644 index 0000000..1c9a2bb --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/balanced.py @@ -0,0 +1,46 @@ +from ..patch_match import PyramidPatchMatcher +import os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class BalancedModeRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + **ebsynth_config + ) + # tasks + n = len(frames_style) + tasks 
= [] + for target in range(n): + for source in range(target - window_size, target + window_size + 1): + if source >= 0 and source < n and source != target: + tasks.append((source, target)) + # run + frames = [(None, 1) for i in range(n)] + for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide = np.stack([frames_guide[source] for source, target in tasks_batch]) + target_guide = np.stack([frames_guide[target] for source, target in tasks_batch]) + source_style = np.stack([frames_style[source] for source, target in tasks_batch]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for (source, target), result in zip(tasks_batch, target_style): + frame, weight = frames[target] + if frame is None: + frame = frames_style[target] + frames[target] = ( + frame * (weight / (weight + 1)) + result / (weight + 1), + weight + 1 + ) + if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size): + frame = frame.clip(0, 255).astype("uint8") + if save_path is not None: + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target)) + frames[target] = (None, 1) diff --git a/diffsynth/extensions/FastBlend/runners/fast.py b/diffsynth/extensions/FastBlend/runners/fast.py new file mode 100644 index 0000000..2ba5731 --- /dev/null +++ b/diffsynth/extensions/FastBlend/runners/fast.py @@ -0,0 +1,141 @@ +from ..patch_match import PyramidPatchMatcher +import functools, os +import numpy as np +from PIL import Image +from tqdm import tqdm + + +class TableManager: + def __init__(self): + pass + + def task_list(self, n): + tasks = [] + max_level = 1 + while (1<=n: + break + meta_data = { + "source": i, + "target": j, + "level": level + 1 + } + tasks.append(meta_data) + tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"])) + return tasks + + def build_remapping_table(self, frames_guide, frames_style, 
patch_match_engine, batch_size, desc=""): + n = len(frames_guide) + tasks = self.task_list(n) + remapping_table = [[(frames_style[i], 1)] for i in range(n)] + for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch]) + target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch]) + source_style = np.stack([frames_style[task["source"]] for task in tasks_batch]) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for task, result in zip(tasks_batch, target_style): + target, level = task["target"], task["level"] + if len(remapping_table[target])==level: + remapping_table[target].append((result, 1)) + else: + frame, weight = remapping_table[target][level] + remapping_table[target][level] = ( + frame * (weight / (weight + 1)) + result / (weight + 1), + weight + 1 + ) + return remapping_table + + def remapping_table_to_blending_table(self, table): + for i in range(len(table)): + for j in range(1, len(table[i])): + frame_1, weight_1 = table[i][j-1] + frame_2, weight_2 = table[i][j] + frame = (frame_1 + frame_2) / 2 + weight = weight_1 + weight_2 + table[i][j] = (frame, weight) + return table + + def tree_query(self, leftbound, rightbound): + node_list = [] + node_index = rightbound + while node_index>=leftbound: + node_level = 0 + while (1<=leftbound: + node_level += 1 + node_list.append((node_index, node_level)) + node_index -= 1<0: + tasks = [] + for m in range(index_style[0]): + tasks.append((index_style[0], m, index_style[0])) + task_group.append(tasks) + # middle frames + for l, r in zip(index_style[:-1], index_style[1:]): + tasks = [] + for m in range(l, r): + tasks.append((l, m, r)) + task_group.append(tasks) + # last frame + tasks = [] + for m in range(index_style[-1], n): + tasks.append((index_style[-1], m, index_style[-1])) + 
task_group.append(tasks) + return task_group + + def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): + patch_match_engine = PyramidPatchMatcher( + image_height=frames_style[0].shape[0], + image_width=frames_style[0].shape[1], + channel=3, + use_mean_target_style=False, + use_pairwise_patch_error=True, + **ebsynth_config + ) + # task + index_dict = self.get_index_dict(index_style) + task_group = self.get_task_group(index_style, len(frames_guide)) + # run + for tasks in task_group: + index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks]) + for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"): + tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))] + source_guide, target_guide, source_style = [], [], [] + for l, m, r in tasks_batch: + # l -> m + source_guide.append(frames_guide[l]) + target_guide.append(frames_guide[m]) + source_style.append(frames_style[index_dict[l]]) + # r -> m + source_guide.append(frames_guide[r]) + target_guide.append(frames_guide[m]) + source_style.append(frames_style[index_dict[r]]) + source_guide = np.stack(source_guide) + target_guide = np.stack(target_guide) + source_style = np.stack(source_style) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + if save_path is not None: + for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch): + weight_l, weight_r = self.get_weight(l, m, r) + frame = frame_l * weight_l + frame_r * weight_r + frame = frame.clip(0, 255).astype("uint8") + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m)) + + +class InterpolationModeSingleFrameRunner: + def __init__(self): + pass + + def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None): + # check input + tracking_window_size = ebsynth_config["tracking_window_size"] + if tracking_window_size 
* 2 >= batch_size: + raise ValueError("batch_size should be larger than track_window_size * 2") + frame_style = frames_style[0] + frame_guide = frames_guide[index_style[0]] + patch_match_engine = PyramidPatchMatcher( + image_height=frame_style.shape[0], + image_width=frame_style.shape[1], + channel=3, + **ebsynth_config + ) + # run + frame_id, n = 0, len(frames_guide) + for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"): + if i + batch_size > n: + l, r = max(n - batch_size, 0), n + else: + l, r = i, i + batch_size + source_guide = np.stack([frame_guide] * (r-l)) + target_guide = np.stack([frames_guide[i] for i in range(l, r)]) + source_style = np.stack([frame_style] * (r-l)) + _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style) + for i, frame in zip(range(l, r), target_style): + if i==frame_id: + frame = frame.clip(0, 255).astype("uint8") + Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id)) + frame_id += 1 + if r < n and r-frame_id <= tracking_window_size: + break diff --git a/diffsynth/extensions/RIFE/__init__.py b/diffsynth/extensions/RIFE/__init__.py new file mode 100644 index 0000000..421def9 --- /dev/null +++ b/diffsynth/extensions/RIFE/__init__.py @@ -0,0 +1,241 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from PIL import Image + + +def warp(tenInput, tenFlow, device): + backwarp_tenGrid = {} + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1).to(device) + + tenFlow = 
torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.PReLU(out_planes) + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),) + self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) + self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1)) + self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1)) + + def forward(self, x, flow, scale=1): + x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. 
/ scale + feat = self.conv0(torch.cat((x, flow), 1)) + feat = self.convblock0(feat) + feat + feat = self.convblock1(feat) + feat + feat = self.convblock2(feat) + feat + feat = self.convblock3(feat) + feat + flow = self.conv1(feat) + mask = self.conv2(feat) + flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale + mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7+4, c=90) + self.block1 = IFBlock(7+4, c=90) + self.block2 = IFBlock(7+4, c=90) + self.block_tea = IFBlock(10+4, c=90) + + def forward(self, x, scale_list=[4, 2, 1], training=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = (x[:, :4]).detach() * 0 + mask = (x[:, :1]).detach() * 0 + block = [self.block0, self.block1, self.block2] + for i in range(3): + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) + f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i]) + flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = mask + (m0 + (-m1)) / 2 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2], device=x.device) + warped_img1 = warp(img1, flow[:, 2:4], device=x.device) + merged.append((warped_img0, warped_img1)) + ''' + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, 1:4] * 2 - 1 + ''' + for i in range(3): + mask_list[i] = torch.sigmoid(mask_list[i]) + merged[i] = merged[i][0] * 
mask_list[i] + merged[i][1] * (1 - mask_list[i]) + return flow_list, mask_list[2], merged + + def state_dict_converter(self): + return IFNetStateDictConverter() + + +class IFNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()} + return state_dict_ + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + + +class RIFEInterpolater: + def __init__(self, model, device="cuda"): + self.model = model + self.device = device + # IFNet only does not support float16 + self.torch_dtype = torch.float32 + + @staticmethod + def from_model_manager(model_manager): + return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) + + def process_image(self, image): + width, height = image.size + if width % 32 != 0 or height % 32 != 0: + width = (width + 31) // 32 + height = (height + 31) // 32 + image = image.resize((width, height)) + image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1) + return image + + def process_images(self, images): + images = [self.process_image(image) for image in images] + images = torch.stack(images) + return images + + def decode_images(self, images): + images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8) + images = [Image.fromarray(image) for image in images] + return images + + def add_interpolated_images(self, images, interpolated_images): + output_images = [] + for image, interpolated_image in zip(images, interpolated_images): + output_images.append(image) + output_images.append(interpolated_image) + output_images.append(images[-1]) + return output_images + + + @torch.no_grad() + def interpolate_(self, images, scale=1.0): + input_tensor = self.process_images(images) + input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1) + input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, 
mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale]) + output_images = self.decode_images(merged[2].cpu()) + if output_images[0].size != images[0].size: + output_images = [image.resize(images[0].size) for image in output_images] + return output_images + + + @torch.no_grad() + def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1): + # Preprocess + processed_images = self.process_images(images) + + for iter in range(num_iter): + # Input + input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1) + + # Interpolate + output_tensor = [] + for batch_id in range(0, input_tensor.shape[0], batch_size): + batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) + batch_input_tensor = input_tensor[batch_id: batch_id_] + batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) + output_tensor.append(merged[2].cpu()) + + # Output + output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1) + processed_images = self.add_interpolated_images(processed_images, output_tensor) + processed_images = torch.stack(processed_images) + + # To images + output_images = self.decode_images(processed_images) + if output_images[0].size != images[0].size: + output_images = [image.resize(images[0].size) for image in output_images] + return output_images + + +class RIFESmoother(RIFEInterpolater): + def __init__(self, model, device="cuda"): + super(RIFESmoother, self).__init__(model, device=device) + + @staticmethod + def from_model_manager(model_manager): + return RIFESmoother(model_manager.RIFE, device=model_manager.device) + + def process_tensors(self, input_tensor, scale=1.0, batch_size=4): + output_tensor = [] + for batch_id in range(0, input_tensor.shape[0], batch_size): + batch_id_ = min(batch_id + batch_size, input_tensor.shape[0]) + batch_input_tensor = input_tensor[batch_id: batch_id_] + batch_input_tensor = 
batch_input_tensor.to(device=self.device, dtype=self.torch_dtype) + flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale]) + output_tensor.append(merged[2].cpu()) + output_tensor = torch.concat(output_tensor, dim=0) + return output_tensor + + @torch.no_grad() + def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs): + # Preprocess + processed_images = self.process_images(rendered_frames) + + for iter in range(num_iter): + # Input + input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1) + + # Interpolate + output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) + + # Blend + input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1) + output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size) + + # Add to frames + processed_images[1:-1] = output_tensor + + # To images + output_images = self.decode_images(processed_images) + if output_images[0].size != rendered_frames[0].size: + output_images = [image.resize(rendered_frames[0].size) for image in output_images] + return output_images diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index dcbe1d7..899d0b5 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -15,8 +15,6 @@ from .sd_controlnet import SDControlNet from .sd_motion import SDMotionModel -from transformers import AutoModelForCausalLM - class ModelManager: def __init__(self, torch_dtype=torch.float16, device="cuda"): @@ -26,6 +24,10 @@ class ModelManager: self.model_path = {} self.textual_inversion_dict = {} + def is_RIFE(self, state_dict): + param_name = "block_tea.convblock3.0.1.weight" + return param_name in state_dict or ("module." 
+ param_name) in state_dict + def is_beautiful_prompt(self, state_dict): param_name = "transformer.h.9.self_attention.query_key_value.weight" return param_name in state_dict @@ -119,6 +121,7 @@ class ModelManager: def load_beautiful_prompt(self, state_dict, file_path=""): component = "beautiful_prompt" + from transformers import AutoModelForCausalLM model_folder = os.path.dirname(file_path) model = AutoModelForCausalLM.from_pretrained( model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype @@ -126,6 +129,15 @@ class ModelManager: self.model[component] = model self.model_path[component] = file_path + def load_RIFE(self, state_dict, file_path=""): + component = "RIFE" + from ..extensions.RIFE import IFNet + model = IFNet().eval() + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model.to(torch.float32).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + def search_for_embeddings(self, state_dict): embeddings = [] for k in state_dict: @@ -163,6 +175,8 @@ class ModelManager: self.load_stable_diffusion(state_dict, components=components, file_path=file_path) elif self.is_beautiful_prompt(state_dict): self.load_beautiful_prompt(state_dict, file_path=file_path) + elif self.is_RIFE(state_dict): + self.load_RIFE(state_dict, file_path=file_path) def load_models(self, file_path_list): for file_path in file_path_list: diff --git a/diffsynth/models/sd_unet.py b/diffsynth/models/sd_unet.py index dcdcb53..a0d937e 100644 --- a/diffsynth/models/sd_unet.py +++ b/diffsynth/models/sd_unet.py @@ -73,7 +73,7 @@ class DownSampler(torch.nn.Module): self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) self.extra_padding = extra_padding - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): if self.extra_padding: hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 
1), mode="constant", value=0) hidden_states = self.conv(hidden_states) @@ -85,7 +85,7 @@ class UpSampler(torch.nn.Module): super().__init__() self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") hidden_states = self.conv(hidden_states) return hidden_states, time_emb, text_emb, res_stack @@ -105,7 +105,7 @@ class ResnetBlock(torch.nn.Module): if in_channels != out_channels: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): x = hidden_states x = self.norm1(x) x = self.nonlinearity(x) @@ -125,7 +125,7 @@ class ResnetBlock(torch.nn.Module): class AttentionBlock(torch.nn.Module): - def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5): + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, cross_attention_dim=None, norm_num_groups=32, eps=1e-5, need_proj_out=True): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -141,10 +141,11 @@ class AttentionBlock(torch.nn.Module): ) for d in range(num_layers) ]) + self.need_proj_out = need_proj_out + if need_proj_out: + self.proj_out = torch.nn.Linear(inner_dim, in_channels) - self.proj_out = torch.nn.Linear(inner_dim, in_channels) - - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, cross_frame_attention=False, **kwargs): batch, _, height, width = hidden_states.shape residual = hidden_states @@ -153,15 +154,25 @@ class AttentionBlock(torch.nn.Module): 
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) + if cross_frame_attention: + hidden_states = hidden_states.reshape(1, batch * height * width, inner_dim) + encoder_hidden_states = text_emb.mean(dim=0, keepdim=True) + else: + encoder_hidden_states = text_emb for block in self.transformer_blocks: hidden_states = block( hidden_states, - encoder_hidden_states=text_emb + encoder_hidden_states=encoder_hidden_states ) + if cross_frame_attention: + hidden_states = hidden_states.reshape(batch, height * width, inner_dim) - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = hidden_states + residual + if self.need_proj_out: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = hidden_states + residual + else: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() return hidden_states, time_emb, text_emb, res_stack @@ -170,7 +181,7 @@ class PushBlock(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): res_stack.append(hidden_states) return hidden_states, time_emb, text_emb, res_stack @@ -179,7 +190,7 @@ class PopBlock(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, hidden_states, time_emb, text_emb, res_stack): + def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs): res_hidden_states = res_stack.pop() hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) return hidden_states, time_emb, text_emb, res_stack diff --git a/diffsynth/models/svd_unet.py b/diffsynth/models/svd_unet.py new file mode 100644 
index 0000000..3ff0d1a --- /dev/null +++ b/diffsynth/models/svd_unet.py @@ -0,0 +1,436 @@ +import torch, math +from einops import rearrange, repeat +from .sd_unet import Timesteps, PushBlock, PopBlock, Attention, GEGLU, ResnetBlock, AttentionBlock, DownSampler, UpSampler + + +class TemporalResnetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): + super().__init__() + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0)) + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) + self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0)) + self.nonlinearity = torch.nn.SiLU() + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + x = rearrange(hidden_states, "f c h w -> 1 c f h w") + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + if time_emb is not None: + emb = self.nonlinearity(time_emb) + emb = self.time_emb_proj(emb) + emb = repeat(emb, "b c -> b c f 1 1", f=hidden_states.shape[0]) + x = x + emb + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.conv2(x) + if self.conv_shortcut is not None: + hidden_states = self.conv_shortcut(hidden_states) + x = rearrange(x[0], "c f h w -> f c h w") + hidden_states = hidden_states + x + return hidden_states, time_emb, text_emb, res_stack + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: 
float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class TemporalTimesteps(torch.nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps): + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + ) + return t_emb + + +class TemporalAttentionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, cross_attention_dim=None): + super().__init__() + + self.positional_embedding = TemporalTimesteps(in_channels, True, 0) + self.positional_embedding_proj = torch.nn.Sequential( + 
torch.nn.Linear(in_channels, in_channels * 4), + torch.nn.SiLU(), + torch.nn.Linear(in_channels * 4, in_channels) + ) + + self.norm_in = torch.nn.LayerNorm(in_channels) + self.act_fn_in = GEGLU(in_channels, in_channels * 4) + self.ff_in = torch.nn.Linear(in_channels * 4, in_channels) + + self.norm1 = torch.nn.LayerNorm(in_channels) + self.attn1 = Attention( + q_dim=in_channels, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + bias_out=True + ) + + self.norm2 = torch.nn.LayerNorm(in_channels) + self.attn2 = Attention( + q_dim=in_channels, + kv_dim=cross_attention_dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + bias_out=True + ) + + self.norm_out = torch.nn.LayerNorm(in_channels) + self.act_fn_out = GEGLU(in_channels, in_channels * 4) + self.ff_out = torch.nn.Linear(in_channels * 4, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + + batch, inner_dim, height, width = hidden_states.shape + pos_emb = torch.arange(batch) + pos_emb = self.positional_embedding(pos_emb).to(dtype=hidden_states.dtype, device=hidden_states.device) + pos_emb = self.positional_embedding_proj(pos_emb)[None, :, :] + + hidden_states = hidden_states.permute(2, 3, 0, 1).reshape(height * width, batch, inner_dim) + hidden_states = hidden_states + pos_emb + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + hidden_states = self.act_fn_in(hidden_states) + hidden_states = self.ff_in(hidden_states) + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=text_emb.repeat(height * width, 1)) + hidden_states = attn_output + hidden_states + + residual = hidden_states + hidden_states = self.norm_out(hidden_states) + hidden_states = 
self.act_fn_out(hidden_states) + hidden_states = self.ff_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states = hidden_states.reshape(height, width, batch, inner_dim).permute(2, 3, 0, 1) + + return hidden_states, time_emb, text_emb, res_stack + + +class PopMixBlock(torch.nn.Module): + def __init__(self, in_channels=None): + super().__init__() + self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5])) + self.need_proj = in_channels is not None + if self.need_proj: + self.proj = torch.nn.Linear(in_channels, in_channels) + + def forward(self, hidden_states, time_emb, text_emb, res_stack): + res_hidden_states = res_stack.pop() + alpha = torch.sigmoid(self.mix_factor) + hidden_states = alpha * res_hidden_states + (1 - alpha) * hidden_states + if self.need_proj: + hidden_states = hidden_states.permute(0, 2, 3, 1) + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.permute(0, 3, 1, 2) + res_hidden_states = res_stack.pop() + hidden_states = hidden_states + res_hidden_states + return hidden_states, time_emb, text_emb, res_stack + + +class SVDUNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.time_proj = Timesteps(320) + self.time_embedding = torch.nn.Sequential( + torch.nn.Linear(320, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.add_time_proj = Timesteps(256) + self.add_time_embedding = torch.nn.Sequential( + torch.nn.Linear(768, 1280), + torch.nn.SiLU(), + torch.nn.Linear(1280, 1280) + ) + self.conv_in = torch.nn.Conv2d(8, 320, kernel_size=3, padding=1) + + self.blocks = torch.nn.ModuleList([ + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(), + ResnetBlock(320, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 
320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), PushBlock(), + DownSampler(320), PushBlock(), + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(320, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(), + ResnetBlock(640, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), PushBlock(), + DownSampler(640), PushBlock(), + # CrossAttnDownBlockSpatioTemporal + ResnetBlock(640, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(), + ResnetBlock(1280, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), PushBlock(), + DownSampler(1280), PushBlock(), + # DownBlockSpatioTemporal + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + # UNetMidBlockSpatioTemporal + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, 
need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), + ResnetBlock(1280, 1280, 1280, eps=1e-5), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + # UpBlockSpatioTemporal + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-5), PopMixBlock(), + UpSampler(1280), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), + PopBlock(), ResnetBlock(2560, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), + PopBlock(), ResnetBlock(1920, 1280, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(1280, 1280, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(20, 64, 1280, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(20, 64, 1280, 1024), PopMixBlock(1280), + UpSampler(1280), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(1920, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), + PopBlock(), ResnetBlock(1280, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), 
PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), + PopBlock(), ResnetBlock(960, 640, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(640, 640, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(10, 64, 640, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(10, 64, 640, 1024), PopMixBlock(640), + UpSampler(640), + # CrossAttnUpBlockSpatioTemporal + PopBlock(), ResnetBlock(960, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), + PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), + PopBlock(), ResnetBlock(640, 320, 1280, eps=1e-6), PushBlock(), TemporalResnetBlock(320, 320, 1280, eps=1e-6), PopMixBlock(), PushBlock(), + AttentionBlock(5, 64, 320, 1, 1024, need_proj_out=False), PushBlock(), TemporalAttentionBlock(5, 64, 320, 1024), PopMixBlock(320), + ]) + + self.conv_norm_out = torch.nn.GroupNorm(32, 320, eps=1e-05, affine=True) + self.conv_act = torch.nn.SiLU() + self.conv_out = torch.nn.Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + + def forward(self, sample, timestep, encoder_hidden_states, add_time_id, **kwargs): + # 1. time + t_emb = self.time_proj(timestep[None]).to(sample.dtype) + t_emb = self.time_embedding(t_emb) + + add_embeds = self.add_time_proj(add_time_id.flatten()).to(sample.dtype) + add_embeds = add_embeds.reshape((-1, 768)) + add_embeds = self.add_time_embedding(add_embeds) + + time_emb = t_emb + add_embeds + + # 2. 
pre-process + height, width = sample.shape[2], sample.shape[3] + hidden_states = self.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 3. blocks + for i, block in enumerate(self.blocks): + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + + # 4. output + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + def state_dict_converter(self): + return SVDUNetStateDictConverter() + + + +class SVDUNetStateDictConverter: + def __init__(self): + pass + + def get_block_name(self, names): + if names[0] in ["down_blocks", "mid_block", "up_blocks"]: + if names[4] in ["norm", "proj_in"]: + return ".".join(names[:4] + ["transformer_blocks"]) + elif names[4] in ["time_pos_embed"]: + return ".".join(names[:4] + ["temporal_transformer_blocks"]) + elif names[4] in ["proj_out"]: + return ".".join(names[:4] + ["time_mixer"]) + else: + return ".".join(names[:5]) + return "" + + def from_diffusers(self, state_dict): + rename_dict = { + "time_embedding.linear_1": "time_embedding.0", + "time_embedding.linear_2": "time_embedding.2", + "add_embedding.linear_1": "add_time_embedding.0", + "add_embedding.linear_2": "add_time_embedding.2", + "conv_in": "conv_in", + "conv_norm_out": "conv_norm_out", + "conv_out": "conv_out", + } + blocks_rename_dict = [ + "down_blocks.0.resnets.0.spatial_res_block", None, "down_blocks.0.resnets.0.temporal_res_block", "down_blocks.0.resnets.0.time_mixer", None, + "down_blocks.0.attentions.0.transformer_blocks", None, "down_blocks.0.attentions.0.temporal_transformer_blocks", "down_blocks.0.attentions.0.time_mixer", None, + "down_blocks.0.resnets.1.spatial_res_block", None, "down_blocks.0.resnets.1.temporal_res_block", "down_blocks.0.resnets.1.time_mixer", None, + "down_blocks.0.attentions.1.transformer_blocks", None, 
"down_blocks.0.attentions.1.temporal_transformer_blocks", "down_blocks.0.attentions.1.time_mixer", None, + "down_blocks.0.downsamplers.0.conv", None, + "down_blocks.1.resnets.0.spatial_res_block", None, "down_blocks.1.resnets.0.temporal_res_block", "down_blocks.1.resnets.0.time_mixer", None, + "down_blocks.1.attentions.0.transformer_blocks", None, "down_blocks.1.attentions.0.temporal_transformer_blocks", "down_blocks.1.attentions.0.time_mixer", None, + "down_blocks.1.resnets.1.spatial_res_block", None, "down_blocks.1.resnets.1.temporal_res_block", "down_blocks.1.resnets.1.time_mixer", None, + "down_blocks.1.attentions.1.transformer_blocks", None, "down_blocks.1.attentions.1.temporal_transformer_blocks", "down_blocks.1.attentions.1.time_mixer", None, + "down_blocks.1.downsamplers.0.conv", None, + "down_blocks.2.resnets.0.spatial_res_block", None, "down_blocks.2.resnets.0.temporal_res_block", "down_blocks.2.resnets.0.time_mixer", None, + "down_blocks.2.attentions.0.transformer_blocks", None, "down_blocks.2.attentions.0.temporal_transformer_blocks", "down_blocks.2.attentions.0.time_mixer", None, + "down_blocks.2.resnets.1.spatial_res_block", None, "down_blocks.2.resnets.1.temporal_res_block", "down_blocks.2.resnets.1.time_mixer", None, + "down_blocks.2.attentions.1.transformer_blocks", None, "down_blocks.2.attentions.1.temporal_transformer_blocks", "down_blocks.2.attentions.1.time_mixer", None, + "down_blocks.2.downsamplers.0.conv", None, + "down_blocks.3.resnets.0.spatial_res_block", None, "down_blocks.3.resnets.0.temporal_res_block", "down_blocks.3.resnets.0.time_mixer", None, + "down_blocks.3.resnets.1.spatial_res_block", None, "down_blocks.3.resnets.1.temporal_res_block", "down_blocks.3.resnets.1.time_mixer", None, + "mid_block.mid_block.resnets.0.spatial_res_block", None, "mid_block.mid_block.resnets.0.temporal_res_block", "mid_block.mid_block.resnets.0.time_mixer", None, + "mid_block.mid_block.attentions.0.transformer_blocks", None, 
"mid_block.mid_block.attentions.0.temporal_transformer_blocks", "mid_block.mid_block.attentions.0.time_mixer", + "mid_block.mid_block.resnets.1.spatial_res_block", None, "mid_block.mid_block.resnets.1.temporal_res_block", "mid_block.mid_block.resnets.1.time_mixer", + None, "up_blocks.0.resnets.0.spatial_res_block", None, "up_blocks.0.resnets.0.temporal_res_block", "up_blocks.0.resnets.0.time_mixer", + None, "up_blocks.0.resnets.1.spatial_res_block", None, "up_blocks.0.resnets.1.temporal_res_block", "up_blocks.0.resnets.1.time_mixer", + None, "up_blocks.0.resnets.2.spatial_res_block", None, "up_blocks.0.resnets.2.temporal_res_block", "up_blocks.0.resnets.2.time_mixer", + "up_blocks.0.upsamplers.0.conv", + None, "up_blocks.1.resnets.0.spatial_res_block", None, "up_blocks.1.resnets.0.temporal_res_block", "up_blocks.1.resnets.0.time_mixer", None, + "up_blocks.1.attentions.0.transformer_blocks", None, "up_blocks.1.attentions.0.temporal_transformer_blocks", "up_blocks.1.attentions.0.time_mixer", + None, "up_blocks.1.resnets.1.spatial_res_block", None, "up_blocks.1.resnets.1.temporal_res_block", "up_blocks.1.resnets.1.time_mixer", None, + "up_blocks.1.attentions.1.transformer_blocks", None, "up_blocks.1.attentions.1.temporal_transformer_blocks", "up_blocks.1.attentions.1.time_mixer", + None, "up_blocks.1.resnets.2.spatial_res_block", None, "up_blocks.1.resnets.2.temporal_res_block", "up_blocks.1.resnets.2.time_mixer", None, + "up_blocks.1.attentions.2.transformer_blocks", None, "up_blocks.1.attentions.2.temporal_transformer_blocks", "up_blocks.1.attentions.2.time_mixer", + "up_blocks.1.upsamplers.0.conv", + None, "up_blocks.2.resnets.0.spatial_res_block", None, "up_blocks.2.resnets.0.temporal_res_block", "up_blocks.2.resnets.0.time_mixer", None, + "up_blocks.2.attentions.0.transformer_blocks", None, "up_blocks.2.attentions.0.temporal_transformer_blocks", "up_blocks.2.attentions.0.time_mixer", + None, "up_blocks.2.resnets.1.spatial_res_block", None, 
"up_blocks.2.resnets.1.temporal_res_block", "up_blocks.2.resnets.1.time_mixer", None, + "up_blocks.2.attentions.1.transformer_blocks", None, "up_blocks.2.attentions.1.temporal_transformer_blocks", "up_blocks.2.attentions.1.time_mixer", + None, "up_blocks.2.resnets.2.spatial_res_block", None, "up_blocks.2.resnets.2.temporal_res_block", "up_blocks.2.resnets.2.time_mixer", None, + "up_blocks.2.attentions.2.transformer_blocks", None, "up_blocks.2.attentions.2.temporal_transformer_blocks", "up_blocks.2.attentions.2.time_mixer", + "up_blocks.2.upsamplers.0.conv", + None, "up_blocks.3.resnets.0.spatial_res_block", None, "up_blocks.3.resnets.0.temporal_res_block", "up_blocks.3.resnets.0.time_mixer", None, + "up_blocks.3.attentions.0.transformer_blocks", None, "up_blocks.3.attentions.0.temporal_transformer_blocks", "up_blocks.3.attentions.0.time_mixer", + None, "up_blocks.3.resnets.1.spatial_res_block", None, "up_blocks.3.resnets.1.temporal_res_block", "up_blocks.3.resnets.1.time_mixer", None, + "up_blocks.3.attentions.1.transformer_blocks", None, "up_blocks.3.attentions.1.temporal_transformer_blocks", "up_blocks.3.attentions.1.time_mixer", + None, "up_blocks.3.resnets.2.spatial_res_block", None, "up_blocks.3.resnets.2.temporal_res_block", "up_blocks.3.resnets.2.time_mixer", None, + "up_blocks.3.attentions.2.transformer_blocks", None, "up_blocks.3.attentions.2.temporal_transformer_blocks", "up_blocks.3.attentions.2.time_mixer", + ] + blocks_rename_dict = {i:j for j,i in enumerate(blocks_rename_dict) if i is not None} + state_dict_ = {} + for name, param in sorted(state_dict.items()): + names = name.split(".") + if names[0] == "mid_block": + names = ["mid_block"] + names + if names[-1] in ["weight", "bias"]: + name_prefix = ".".join(names[:-1]) + if name_prefix in rename_dict: + state_dict_[rename_dict[name_prefix] + "." 
+ names[-1]] = param + else: + block_name = self.get_block_name(names) + if "resnets" in block_name and block_name in blocks_rename_dict: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[5:]) + state_dict_[rename] = param + elif ("downsamplers" in block_name or "upsamplers" in block_name) and block_name in blocks_rename_dict: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name])] + names[-2:]) + state_dict_[rename] = param + elif "attentions" in block_name and block_name in blocks_rename_dict: + attention_id = names[5] + if "transformer_blocks" in names: + suffix_dict = { + "attn1.to_out.0": "attn1.to_out", + "attn2.to_out.0": "attn2.to_out", + "ff.net.0.proj": "act_fn.proj", + "ff.net.2": "ff", + } + suffix = ".".join(names[6:-1]) + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "transformer_blocks", attention_id, suffix, names[-1]]) + elif "temporal_transformer_blocks" in names: + suffix_dict = { + "attn1.to_out.0": "attn1.to_out", + "attn2.to_out.0": "attn2.to_out", + "ff_in.net.0.proj": "act_fn_in.proj", + "ff_in.net.2": "ff_in", + "ff.net.0.proj": "act_fn_out.proj", + "ff.net.2": "ff_out", + "norm3": "norm_out", + } + suffix = ".".join(names[6:-1]) + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]]) + elif "time_mixer" in block_name: + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), "proj", names[-1]]) + else: + suffix_dict = { + "linear_1": "positional_embedding_proj.0", + "linear_2": "positional_embedding_proj.2", + } + suffix = names[-2] + suffix = suffix_dict.get(suffix, suffix) + rename = ".".join(["blocks", str(blocks_rename_dict[block_name]), suffix, names[-1]]) + state_dict_[rename] = param + else: + print(name) + else: + block_name = self.get_block_name(names) + if len(block_name)>0 and block_name in blocks_rename_dict: + rename = ".".join(["blocks", 
str(blocks_rename_dict[block_name]), names[-1]]) + state_dict_[rename] = param + return state_dict_ diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py index 800876b..67ff24e 100644 --- a/diffsynth/pipelines/dancer.py +++ b/diffsynth/pipelines/dancer.py @@ -15,6 +15,7 @@ def lets_dance( controlnet_frames = None, unet_batch_size = 1, controlnet_batch_size = 1, + cross_frame_attention = False, tiled=False, tile_size=64, tile_stride=32, @@ -86,7 +87,13 @@ def lets_dance( tile_dtype=hidden_states.dtype ) else: - hidden_states, _, _, _ = block(hidden_states_input[batch_id: batch_id_], time_emb, text_emb[batch_id: batch_id_], res_stack) + hidden_states, _, _, _ = block( + hidden_states_input[batch_id: batch_id_], + time_emb, + text_emb[batch_id: batch_id_], + res_stack, + cross_frame_attention=cross_frame_attention + ) hidden_states_output.append(hidden_states) hidden_states = torch.concat(hidden_states_output, dim=0) # 4.2 AnimateDiff diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index 1aa2f3e..6b78261 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -18,10 +18,11 @@ def lets_dance_with_long_video( timestep = None, encoder_hidden_states = None, controlnet_frames = None, - unet_batch_size = 1, - controlnet_batch_size = 1, animatediff_batch_size = 16, animatediff_stride = 8, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, device = "cuda", vram_limit_level = 0, ): @@ -38,12 +39,14 @@ def lets_dance_with_long_video( timestep, encoder_hidden_states[batch_id: batch_id_].to(device), controlnet_frames[:, batch_id: batch_id_].to(device) if controlnet_frames is not None else None, - unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, device=device, vram_limit_level=vram_limit_level + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + 
cross_frame_attention=cross_frame_attention, + device=device, vram_limit_level=vram_limit_level ).cpu() # update hidden_states for i, hidden_states_updated in zip(range(batch_id, batch_id_), hidden_states_batch): - bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1) / 2), 1e-2) + bias = max(1 - abs(i - (batch_id + batch_id_ - 1) / 2) / ((batch_id_ - batch_id - 1 + 1e-2) / 2), 1e-2) hidden_states, num = hidden_states_output[i] hidden_states = hidden_states * (num / (num + bias)) + hidden_states_updated * (bias / (num + bias)) hidden_states_output[i] = (hidden_states, num + 1) @@ -159,6 +162,13 @@ class SDVideoPipeline(torch.nn.Module): height=512, width=512, num_inference_steps=20, + animatediff_batch_size = 16, + animatediff_stride = 8, + unet_batch_size = 1, + controlnet_batch_size = 1, + cross_frame_attention = False, + smoother=None, + smoother_progress_ids=[], vram_limit_level=0, progress_bar_cmd=tqdm, progress_bar_st=None, @@ -167,8 +177,11 @@ class SDVideoPipeline(torch.nn.Module): self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype) - if input_frames is None: + if self.motion_modules is None: + noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1) + else: + noise = torch.randn((num_frames, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype) + if input_frames is None or denoising_strength == 1.0: latents = noise else: latents = self.encode_images(input_frames) @@ -195,16 +208,28 @@ class SDVideoPipeline(torch.nn.Module): noise_pred_posi = lets_dance_with_long_video( self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames, + animatediff_batch_size=animatediff_batch_size, 
animatediff_stride=animatediff_stride, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, device=self.device, vram_limit_level=vram_limit_level ) noise_pred_nega = lets_dance_with_long_video( self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, + animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, device=self.device, vram_limit_level=vram_limit_level ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - # DDIM + # DDIM and smoother + if smoother is not None and progress_id in smoother_progress_ids: + rendered_frames = self.scheduler.step(noise_pred, timestep, latents, to_final=True) + rendered_frames = self.decode_images(rendered_frames) + rendered_frames = smoother(rendered_frames, original_frames=input_frames) + target_latents = self.encode_images(rendered_frames) + noise_pred = self.scheduler.return_to_timestep(timestep, latents, target_latents) latents = self.scheduler.step(noise_pred, timestep, latents) # UI diff --git a/examples/sd_text_to_video.py b/examples/sd_text_to_video.py new file mode 100644 index 0000000..a92d71a --- /dev/null +++ b/examples/sd_text_to_video.py @@ -0,0 +1,47 @@ +from diffsynth import ModelManager, SDImagePipeline, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames +from diffsynth.extensions.RIFE import RIFEInterpolater +import torch + + +# Download models +# `models/stable_diffusion/dreamshaper_8.safetensors`: [link](https://civitai.com/api/download/models/128713?type=Model&format=SafeTensor&size=pruned&fp=fp16) +# `models/AnimateDiff/mm_sd_v15_v2.ckpt`: 
[link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sd_v15_v2.ckpt) +# `models/RIFE/flownet.pkl`: [link](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_models([ + "models/stable_diffusion/dreamshaper_8.safetensors", + "models/AnimateDiff/mm_sd_v15_v2.ckpt", + "models/RIFE/flownet.pkl" +]) + +# Text -> Image +pipe_image = SDImagePipeline.from_model_manager(model_manager) +torch.manual_seed(0) +image = pipe_image( + prompt = "lightning storm, sea", + negative_prompt = "", + cfg_scale=7.5, + num_inference_steps=30, height=512, width=768, +) + +# Text + Image -> Video (6GB VRAM is enough!) +pipe = SDVideoPipeline.from_model_manager(model_manager) +output_video = pipe( + prompt = "lightning storm, sea", + negative_prompt = "", + cfg_scale=7.5, + num_frames=64, + num_inference_steps=10, height=512, width=768, + animatediff_batch_size=16, animatediff_stride=1, input_frames=[image]*64, denoising_strength=0.9, + vram_limit_level=0, +) + +# Video -> Video with high fps +interpolater = RIFEInterpolater.from_model_manager(model_manager) +output_video = interpolater.interpolate(output_video, num_iter=3) + +# Save images and video +save_video(output_video, "output_video.mp4", fps=120) diff --git a/examples/sd_toon_shading.py b/examples/sd_toon_shading.py index df9f3a9..8aadff3 100644 --- a/examples/sd_toon_shading.py +++ b/examples/sd_toon_shading.py @@ -1,4 +1,5 @@ from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video, save_frames +from diffsynth.extensions.RIFE import RIFESmoother import torch @@ -9,6 +10,8 @@ import torch # `models/ControlNet/control_v11f1e_sd15_tile.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1e_sd15_tile.pth) # `models/Annotators/sk_model.pth`: 
[link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model.pth) # `models/Annotators/sk_model2.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/sk_model2.pth) +# `models/textual_inversion/verybadimagenegative_v1.3.pt`: [link](https://civitai.com/api/download/models/25820?type=Model&format=PickleTensor&size=full&fp=fp16) +# `models/RIFE/flownet.pkl`: [link](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing) # Load models @@ -19,6 +22,7 @@ model_manager.load_models([ "models/AnimateDiff/mm_sd_v15_v2.ckpt", "models/ControlNet/control_v11p_sd15_lineart.pth", "models/ControlNet/control_v11f1e_sd15_tile.pth", + "models/RIFE/flownet.pkl" ]) pipe = SDVideoPipeline.from_model_manager( model_manager, @@ -26,31 +30,36 @@ pipe = SDVideoPipeline.from_model_manager( ControlNetConfigUnit( processor_id="lineart", model_path="models/ControlNet/control_v11p_sd15_lineart.pth", - scale=1.0 + scale=0.5 ), ControlNetConfigUnit( processor_id="tile", model_path="models/ControlNet/control_v11f1e_sd15_tile.pth", scale=0.5 - ), + ) ] ) +smoother = RIFESmoother.from_model_manager(model_manager) -# Load video (we only use 16 frames in this example for testing) -video = VideoData(video_file="input_video.mp4", height=1536, width=1536) -input_video = [video[i] for i in range(16)] +# Load video (we only use 60 frames for quick testing) +# The original video is here: https://www.bilibili.com/video/BV19w411A7YJ/ +video = VideoData( + video_file="data/bilibili_videos/៸៸᳐_⩊_៸៸᳐ 66 微笑调查队🌻/៸៸᳐_⩊_៸៸᳐ 66 微笑调查队🌻 - 1.66 微笑调查队🌻(Av278681824,P1).mp4", + height=1024, width=1024) +input_video = [video[i] for i in range(40*60, 41*60)] -# Toon shading +# Toon shading (20G VRAM) torch.manual_seed(0) output_video = pipe( prompt="best quality, perfect anime illustration, light, a girl is dancing, smile, solo", negative_prompt="verybadimagenegative_v1.3", - cfg_scale=5, clip_skip=2, + cfg_scale=3, clip_skip=2, controlnet_frames=input_video, 
num_frames=len(input_video), - num_inference_steps=10, height=1536, width=1536, + num_inference_steps=10, height=1024, width=1024, + animatediff_batch_size=32, animatediff_stride=16, vram_limit_level=0, ) +output_video = smoother(output_video) -# Save images and video -save_frames(output_video, "output_frames") -save_video(output_video, "output_video.mp4", fps=16) +# Save video +save_video(output_video, "output_video.mp4", fps=60) diff --git a/examples/sd_video_rerender.py b/examples/sd_video_rerender.py new file mode 100644 index 0000000..75748c5 --- /dev/null +++ b/examples/sd_video_rerender.py @@ -0,0 +1,58 @@ +from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video +from diffsynth.extensions.FastBlend import FastBlendSmoother +import torch + + +# Download models +# `models/stable_diffusion/dreamshaper_8.safetensors`: [link](https://civitai.com/api/download/models/128713?type=Model&format=SafeTensor&size=pruned&fp=fp16) +# `models/ControlNet/control_v11f1p_sd15_depth.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11f1p_sd15_depth.pth) +# `models/ControlNet/control_v11p_sd15_softedge.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth) +# `models/Annotators/dpt_hybrid-midas-501f0c75.pt`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt) +# `models/Annotators/ControlNetHED.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth) +# `models/RIFE/flownet.pkl`: [link](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing) + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda") +model_manager.load_models([ + "models/stable_diffusion/dreamshaper_8.safetensors", + "models/ControlNet/control_v11f1p_sd15_depth.pth", + "models/ControlNet/control_v11p_sd15_softedge.pth", + "models/RIFE/flownet.pkl" +]) 
+pipe = SDVideoPipeline.from_model_manager( + model_manager, + [ + ControlNetConfigUnit( + processor_id="depth", + model_path=rf"models/ControlNet/control_v11f1p_sd15_depth.pth", + scale=0.5 + ), + ControlNetConfigUnit( + processor_id="softedge", + model_path=rf"models/ControlNet/control_v11p_sd15_softedge.pth", + scale=0.5 + ) + ] +) +smoother = FastBlendSmoother.from_model_manager(model_manager) + +# Load video +# Original video: https://pixabay.com/videos/flow-rocks-water-fluent-stones-159627/ +video = VideoData(video_file="data/pixabay100/159627 (1080p).mp4", height=512, width=768) +input_video = [video[i] for i in range(128)] + +# Rerender +torch.manual_seed(0) +output_video = pipe( + prompt="winter, ice, snow, water, river", + negative_prompt="", cfg_scale=7, + input_frames=input_video, controlnet_frames=input_video, num_frames=len(input_video), + num_inference_steps=10, height=512, width=768, + animatediff_batch_size=32, animatediff_stride=16, unet_batch_size=4, + cross_frame_attention=True, + smoother=smoother, smoother_progress_ids=[4, 9] +) + +# Save video +save_video(output_video, "output_video.mp4", fps=30)