From eab2dcbf7b3cb8096ffc417fb90712aba77aefc4 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 5 Feb 2024 01:43:55 +0800
Subject: [PATCH] update smoothers

---
 diffsynth/processors/FastBlend.py            | 159 +++++++++++++++++++
 diffsynth/processors/PILEditor.py            |  32 ++++
 diffsynth/processors/RIFE.py                 |  85 ++++++++++
 diffsynth/processors/__init__.py             |   0
 diffsynth/processors/base.py                 |   8 +
 diffsynth/processors/sequencial_processor.py |  18 ++
 examples/sd_video_rerender.py                |  17 ++-
 7 files changed, 310 insertions(+), 9 deletions(-)
 create mode 100644 diffsynth/processors/FastBlend.py
 create mode 100644 diffsynth/processors/PILEditor.py
 create mode 100644 diffsynth/processors/RIFE.py
 create mode 100644 diffsynth/processors/__init__.py
 create mode 100644 diffsynth/processors/base.py
 create mode 100644 diffsynth/processors/sequencial_processor.py

diff --git a/diffsynth/processors/FastBlend.py b/diffsynth/processors/FastBlend.py
new file mode 100644
index 0000000..fed33f4
--- /dev/null
+++ b/diffsynth/processors/FastBlend.py
@@ -0,0 +1,159 @@
+from PIL import Image
+import cupy as cp
+import numpy as np
+from tqdm import tqdm
+from ..extensions.FastBlend.patch_match import PyramidPatchMatcher
+from ..extensions.FastBlend.runners.fast import TableManager
+from .base import VideoProcessor
+
+
+class FastBlendSmoother(VideoProcessor):
+    """Blend every rendered frame with its neighbours, using patch matches computed on the guide video."""
+
+    def __init__(
+        self,
+        inference_mode="fast", batch_size=8, window_size=60,
+        minimum_patch_size=5, threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0, initialize="identity", tracking_window_size=0
+    ):
+        # One of "fast", "balanced" or "accurate"; checked in __call__.
+        self.inference_mode = inference_mode
+        self.batch_size = batch_size
+        # Balanced/accurate modes blend frames [t - window_size, t + window_size] into frame t.
+        self.window_size = window_size
+        # Forwarded as keyword arguments to PyramidPatchMatcher.
+        self.ebsynth_config = {
+            "minimum_patch_size": minimum_patch_size,
+            "threads_per_block": threads_per_block,
+            "num_iter": num_iter,
+            "gpu_id": gpu_id,
+            "guide_weight": guide_weight,
+            "initialize": initialize,
+            "tracking_window_size": tracking_window_size
+        }
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        # TODO: fetch GPU ID from model_manager
+        return FastBlendSmoother(**kwargs)
+
+    def inference_fast(self, frames_guide, frames_style):
+        table_manager = TableManager()
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        # left part
+        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, self.batch_size, desc="Fast Mode Step 1/4")
+        table_l = table_manager.remapping_table_to_blending_table(table_l)
+        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 2/4")
+        # right part
+        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, self.batch_size, desc="Fast Mode Step 3/4")
+        table_r = table_manager.remapping_table_to_blending_table(table_r)
+        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, self.window_size, self.batch_size, desc="Fast Mode Step 4/4")[::-1]
+        # merge
+        frames = []
+        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
+            # NOTE(review): presumably both directional sums already contain the frame itself, hence the -1 - confirm.
+            weight_m = -1
+            weight = weight_l + weight_m + weight_r
+            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
+            frames.append(frame)
+        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
+        frames = [Image.fromarray(frame) for frame in frames]
+        return frames
+
+    def inference_balanced(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            **self.ebsynth_config
+        )
+        n = len(frames_style)
+        output_frames = [None] * n
+        # tasks
+        tasks = []
+        for target in range(n):
+            for source in range(target - self.window_size, target + self.window_size + 1):
+                if source >= 0 and source < n and source != target:
+                    tasks.append((source, target))
+        # run
+        frames = [(None, 1) for i in range(n)]
+        for batch_id in tqdm(range(0, len(tasks), self.batch_size), desc="Balanced Mode"):
+            tasks_batch = tasks[batch_id: min(batch_id+self.batch_size, len(tasks))]
+            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
+            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
+            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
+            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+            for (source, target), result in zip(tasks_batch, target_style):
+                frame, weight = frames[target]
+                if frame is None:
+                    frame = frames_style[target]
+                # Running mean of the frame itself and the neighbours remapped onto it so far.
+                frame = frame * (weight / (weight + 1)) + result / (weight + 1)
+                weight += 1
+                frames[target] = (frame, weight)
+                if weight == min(n, target + self.window_size + 1) - max(0, target - self.window_size):
+                    # Emit only after the last neighbour is blended in, then drop the accumulator.
+                    output_frames[target] = Image.fromarray(frame.clip(0, 255).astype("uint8"))
+                    frames[target] = (None, 1)
+        # Frames without any neighbour in the window (n == 1 or window_size == 0) pass through unchanged.
+        for target, frame in enumerate(output_frames):
+            if frame is None:
+                output_frames[target] = Image.fromarray(frames_style[target].clip(0, 255).astype("uint8"))
+        return output_frames
+
+    def inference_accurate(self, frames_guide, frames_style):
+        patch_match_engine = PyramidPatchMatcher(
+            image_height=frames_style[0].shape[0],
+            image_width=frames_style[0].shape[1],
+            channel=3,
+            use_mean_target_style=True,
+            **self.ebsynth_config
+        )
+        output_frames = []
+        # run
+        n = len(frames_style)
+        for target in tqdm(range(n), desc="Accurate Mode"):
+            l, r = max(target - self.window_size, 0), min(target + self.window_size + 1, n)
+            remapped_frames = []
+            for i in range(l, r, self.batch_size):
+                j = min(i + self.batch_size, r)
+                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
+                target_guide = np.stack([frames_guide[target]] * (j - i))
+                source_style = np.stack([frames_style[source] for source in range(i, j)])
+                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
+                remapped_frames.append(target_style)
+            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
+            frame = frame.clip(0, 255).astype("uint8")
+            output_frames.append(Image.fromarray(frame))
+        return output_frames
+
+    def release_vram(self):
+        mempool = cp.get_default_memory_pool()
+        pinned_mempool = cp.get_default_pinned_memory_pool()
+        mempool.free_all_blocks()
+        pinned_mempool.free_all_blocks()
+
+    def __call__(self, rendered_frames, original_frames=None, **kwargs):
+        if original_frames is None:
+            # The guide video is mandatory; fail clearly instead of "'NoneType' object is not iterable".
+            raise TypeError("FastBlendSmoother requires original_frames")
+        if self.inference_mode == "fast":
+            inference = self.inference_fast
+        elif self.inference_mode == "balanced":
+            inference = self.inference_balanced
+        elif self.inference_mode == "accurate":
+            inference = self.inference_accurate
+        else:
+            raise ValueError("inference_mode must be fast, balanced or accurate")
+        rendered_frames = [np.array(frame) for frame in rendered_frames]
+        original_frames = [np.array(frame) for frame in original_frames]
+        try:
+            output_frames = inference(original_frames, rendered_frames)
+        finally:
+            # Return the CuPy pools even when inference raises.
+            self.release_vram()
+        return output_frames
diff --git a/diffsynth/processors/PILEditor.py b/diffsynth/processors/PILEditor.py
new file mode 100644
index 0000000..01011d8
--- /dev/null
+++ b/diffsynth/processors/PILEditor.py
@@ -0,0 +1,32 @@
+from PIL import ImageEnhance
+from .base import VideoProcessor
+
+
+class ContrastEditor(VideoProcessor):
+    """Apply PIL's ImageEnhance.Contrast with factor `rate` to every frame."""
+
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return ContrastEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Contrast(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
+
+
+class SharpnessEditor(VideoProcessor):
+    """Apply PIL's ImageEnhance.Sharpness with factor `rate` to every frame."""
+
+    def __init__(self, rate=1.5):
+        self.rate = rate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return SharpnessEditor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        rendered_frames = [ImageEnhance.Sharpness(i).enhance(self.rate) for i in rendered_frames]
+        return rendered_frames
diff --git a/diffsynth/processors/RIFE.py b/diffsynth/processors/RIFE.py
new file mode 100644
index 0000000..4186eb3
--- /dev/null
+++ b/diffsynth/processors/RIFE.py
@@ -0,0 +1,85 @@
+import torch
+import numpy as np
+from PIL import Image
+from .base import VideoProcessor
+
+
+class RIFESmoother(VideoProcessor):
+    """Smooth a video by blending every inner frame with the RIFE model's output for its two neighbours."""
+
+    def __init__(self, model, device="cuda", scale=1.0, batch_size=4, interpolate=True):
+        self.model = model
+        self.device = device
+
+        # IFNet does not support float16
+        self.torch_dtype = torch.float32
+
+        # Other parameters
+        self.scale = scale
+        self.batch_size = batch_size
+        self.interpolate = interpolate
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return RIFESmoother(model_manager.RIFE, device=model_manager.device, **kwargs)
+
+    def process_image(self, image):
+        width, height = image.size
+        if width % 32 != 0 or height % 32 != 0:
+            # Round up to the next multiple of 32; without "* 32" the frame was shrunk to a few pixels.
+            width = (width + 31) // 32 * 32
+            height = (height + 31) // 32 * 32
+            image = image.resize((width, height))
+        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
+        return image
+
+    def process_images(self, images):
+        images = [self.process_image(image) for image in images]
+        images = torch.stack(images)
+        return images
+
+    def decode_images(self, images):
+        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
+        images = [Image.fromarray(image) for image in images]
+        return images
+
+    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
+        output_tensor = []
+        for batch_id in range(0, input_tensor.shape[0], batch_size):
+            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
+            batch_input_tensor = input_tensor[batch_id: batch_id_]
+            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
+            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
+            output_tensor.append(merged[2].cpu())
+        output_tensor = torch.concat(output_tensor, dim=0)
+        return output_tensor
+
+    @torch.no_grad()
+    def __call__(self, rendered_frames, **kwargs):
+        # Smoothing needs both neighbours of a frame; with fewer than three frames there is nothing
+        # to smooth, and stacking/concatenating the empty batch list would raise.
+        if len(rendered_frames) < 3:
+            return list(rendered_frames)
+
+        # Preprocess
+        processed_images = self.process_images(rendered_frames)
+
+        # Input
+        input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
+
+        # Interpolate
+        output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+
+        if self.interpolate:
+            # Blend
+            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
+            output_tensor = self.process_tensors(input_tensor, scale=self.scale, batch_size=self.batch_size)
+            processed_images[1:-1] = output_tensor
+        else:
+            processed_images[1:-1] = (processed_images[1:-1] + output_tensor) / 2
+
+        # To images
+        output_images = self.decode_images(processed_images)
+        if output_images[0].size != rendered_frames[0].size:
+            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
+        return output_images
diff --git a/diffsynth/processors/__init__.py b/diffsynth/processors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/diffsynth/processors/base.py b/diffsynth/processors/base.py
new file mode 100644
index 0000000..278a9c1
--- /dev/null
+++ b/diffsynth/processors/base.py
@@ -0,0 +1,8 @@
+class VideoProcessor:
+    """Base class of the video processors; subclasses implement __call__ and return the processed frames."""
+
+    def __init__(self):
+        pass
+
+    def __call__(self, rendered_frames, **kwargs):
+        raise NotImplementedError
diff --git a/diffsynth/processors/sequencial_processor.py b/diffsynth/processors/sequencial_processor.py
new file mode 100644
index 0000000..6f6b440
--- /dev/null
+++ b/diffsynth/processors/sequencial_processor.py
@@ -0,0 +1,18 @@
+from .base import VideoProcessor
+
+
+class SequencialProcessor(VideoProcessor):
+    """Apply a list of processors one after another, feeding each one the output of the previous one."""
+
+    def __init__(self, processors=None):
+        # A `[]` default would be a single list shared by every instance created without arguments.
+        self.processors = [] if processors is None else processors
+
+    @staticmethod
+    def from_model_manager(model_manager, **kwargs):
+        return SequencialProcessor(**kwargs)
+
+    def __call__(self, rendered_frames, **kwargs):
+        for processor in self.processors:
+            rendered_frames = processor(rendered_frames, **kwargs)
+        return rendered_frames
\ No newline at end of file
diff --git a/examples/sd_video_rerender.py b/examples/sd_video_rerender.py
index 75748c5..0e82442 100644
--- a/examples/sd_video_rerender.py
+++ b/examples/sd_video_rerender.py
@@ -1,5 +1,7 @@
 from diffsynth import ModelManager, SDVideoPipeline, ControlNetConfigUnit, VideoData, save_video
-from diffsynth.extensions.FastBlend import FastBlendSmoother
+from diffsynth.processors.FastBlend import FastBlendSmoother
+from diffsynth.processors.PILEditor import ContrastEditor, SharpnessEditor
+from diffsynth.processors.sequencial_processor import SequencialProcessor
 import torch
 
 
@@ -9,16 +11,13 @@ import torch
 # `models/ControlNet/control_v11p_sd15_softedge.pth`: [link](https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_softedge.pth)
 # `models/Annotators/dpt_hybrid-midas-501f0c75.pt`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt)
 # `models/Annotators/ControlNetHED.pth`: [link](https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth)
-# `models/RIFE/flownet.pkl`: [link](https://drive.google.com/file/d/1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_/view?usp=sharing)
-
 
 # Load models
 model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
 model_manager.load_models([
     "models/stable_diffusion/dreamshaper_8.safetensors",
     "models/ControlNet/control_v11f1p_sd15_depth.pth",
-    "models/ControlNet/control_v11p_sd15_softedge.pth",
-    "models/RIFE/flownet.pkl"
+    "models/ControlNet/control_v11p_sd15_softedge.pth"
 ])
 pipe = SDVideoPipeline.from_model_manager(
     model_manager,
@@ -35,7 +34,7 @@ pipe = SDVideoPipeline.from_model_manager(
         )
     ]
 )
-smoother = FastBlendSmoother.from_model_manager(model_manager)
+smoother = SequencialProcessor([FastBlendSmoother(), ContrastEditor(rate=1.1), SharpnessEditor(rate=1.1)])
 
 # Load video
 # Original video: https://pixabay.com/videos/flow-rocks-water-fluent-stones-159627/
@@ -48,10 +47,10 @@ output_video = pipe(
     prompt="winter, ice, snow, water, river",
     negative_prompt="", cfg_scale=7,
     input_frames=input_video, controlnet_frames=input_video, num_frames=len(input_video),
-    num_inference_steps=10, height=512, width=768,
-    animatediff_batch_size=32, animatediff_stride=16, unet_batch_size=4,
+    num_inference_steps=20, height=512, width=768,
+    animatediff_batch_size=8, animatediff_stride=4, unet_batch_size=8,
     cross_frame_attention=True,
-    smoother=smoother, smoother_progress_ids=[4, 9]
+    smoother=smoother, smoother_progress_ids=[4, 9, 14, 19]
 )
 
 # Save images and video