support video-to-video-translation

This commit is contained in:
Artiprocher
2023-12-21 17:11:58 +08:00
parent f7f4c1038e
commit c1453281df
20 changed files with 1659 additions and 427 deletions

View File

@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

View File

@@ -0,0 +1,48 @@
import torch
import numpy as np
from .processors import Processor_id
class ControlNetConfigUnit:
def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
self.processor_id = processor_id
self.model_path = model_path
self.scale = scale
class ControlNetUnit:
def __init__(self, processor, model, scale=1.0):
self.processor = processor
self.model = model
self.scale = scale
class MultiControlNetManager:
def __init__(self, controlnet_units=[]):
self.processors = [unit.processor for unit in controlnet_units]
self.models = [unit.model for unit in controlnet_units]
self.scales = [unit.scale for unit in controlnet_units]
def process_image(self, image, return_image=False):
processed_image = [
processor(image)
for processor in self.processors
]
if return_image:
return processed_image
processed_image = torch.concat([
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
for image_ in processed_image
], dim=0)
return processed_image
def __call__(self, sample, timestep, encoder_hidden_states, conditionings):
res_stack = None
for conditioning, model, scale in zip(conditionings, self.models, self.scales):
res_stack_ = model(sample, timestep, encoder_hidden_states, conditioning)
res_stack_ = [res * scale for res in res_stack_]
if res_stack is None:
res_stack = res_stack_
else:
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
return res_stack

View File

@@ -0,0 +1,50 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from controlnet_aux.processor import (
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
)
Processor_id: TypeAlias = Literal[
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]
class Annotator:
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512):
if processor_id == "canny":
self.processor = CannyDetector()
elif processor_id == "depth":
self.processor = MidasDetector.from_pretrained(model_path)
elif processor_id == "softedge":
self.processor = HEDdetector.from_pretrained(model_path)
elif processor_id == "lineart":
self.processor = LineartDetector.from_pretrained(model_path)
elif processor_id == "lineart_anime":
self.processor = LineartAnimeDetector.from_pretrained(model_path)
elif processor_id == "openpose":
self.processor = OpenposeDetector.from_pretrained(model_path)
elif processor_id == "tile":
self.processor = None
else:
raise ValueError(f"Unsupported processor_id: {processor_id}")
self.processor_id = processor_id
self.detect_resolution = detect_resolution
def __call__(self, image):
width, height = image.size
if self.processor_id == "openpose":
kwargs = {
"include_body": True,
"include_hand": True,
"include_face": True
}
else:
kwargs = {}
if self.processor is not None:
image = self.processor(image, detect_resolution=self.detect_resolution, image_resolution=min(width, height), **kwargs)
image = image.resize((width, height))
return image