mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support video-to-video-translation
This commit is contained in:
2
diffsynth/controlnets/__init__.py
Normal file
2
diffsynth/controlnets/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
|
||||
from .processors import Annotator
|
||||
48
diffsynth/controlnets/controlnet_unit.py
Normal file
48
diffsynth/controlnets/controlnet_unit.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from .processors import Processor_id
|
||||
|
||||
|
||||
class ControlNetConfigUnit:
    """Settings record describing a single ControlNet.

    The three constructor arguments are stored unchanged as attributes of
    the same name; no validation or loading happens here.

    Args:
        processor_id: One of the ``Processor_id`` annotator identifiers.
        model_path: Location of the ControlNet weights
            (presumably a filesystem path; not interpreted by this class).
        scale: Weight associated with this ControlNet. Defaults to ``1.0``.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id, self.model_path, self.scale = processor_id, model_path, scale
|
||||
|
||||
|
||||
class ControlNetUnit:
    """Bundle of an image processor, a ControlNet model and its scale.

    The arguments are stored unchanged as the attributes ``processor``,
    ``model`` and ``scale``.

    Args:
        processor: Object that pre-processes an input image.
        model: The ControlNet model paired with ``processor``.
        scale: Weight associated with this unit. Defaults to ``1.0``.
    """

    def __init__(self, processor, model, scale=1.0):
        self.processor, self.model, self.scale = processor, model, scale
|
||||
|
||||
|
||||
class MultiControlNetManager:
    """Drive several ControlNet units together and sum their outputs.

    The manager keeps three parallel lists (``processors``, ``models`` and
    ``scales``), each holding the matching attribute of every unit in the
    order the units were given.
    """

    def __init__(self, controlnet_units=None):
        """Collect processors, models and scales from ``controlnet_units``.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes. ``None`` (the default)
                means "no units".
        """
        # ``None`` replaces the former shared mutable ``[]`` default, and the
        # input is materialized once so that one-shot iterables (generators)
        # feed all three lists instead of only the first.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, return_image=False):
        """Run every processor on ``image``.

        Args:
            image: Input handed unchanged to each processor.
            return_image: When true, return the raw processor outputs as a
                list (one entry per processor) without tensor conversion.

        Returns:
            Either the list of processor outputs, or a tensor built by
            converting each output to a float32 array, dividing by 255,
            moving the last axis to the front (the outputs are assumed to be
            height x width x channel arrays), adding a leading axis and
            concatenating all results along dimension 0.
        """
        processed_image = [processor(image) for processor in self.processors]
        if return_image:
            return processed_image
        return torch.concat(
            [
                torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
                for image_ in processed_image
            ],
            dim=0,
        )

    def __call__(self, sample, timestep, encoder_hidden_states, conditionings):
        """Evaluate every model and return the scaled, summed result stack.

        Each model is called as
        ``model(sample, timestep, encoder_hidden_states, conditioning)`` and
        must return an iterable of results. Every result is multiplied by the
        unit's scale, and the stacks of all units are added element-wise.

        Args:
            sample: Forwarded to every model.
            timestep: Forwarded to every model.
            encoder_hidden_states: Forwarded to every model.
            conditionings: Iterable yielding one conditioning per model, in
                the same order as the units.

        Returns:
            The list of summed results, or ``None`` when there is nothing to
            evaluate (no units or no conditionings).
        """
        res_stack = None
        # zip() stops at the shortest input, so surplus conditionings or
        # models are ignored rather than raising.
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            scaled = [res * scale for res in model(sample, timestep, encoder_hidden_states, conditioning)]
            if res_stack is None:
                res_stack = scaled
            else:
                res_stack = [total + extra for total, extra in zip(res_stack, scaled)]
        return res_stack
|
||||
50
diffsynth/controlnets/processors.py
Normal file
50
diffsynth/controlnets/processors.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
from controlnet_aux.processor import (
|
||||
CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
|
||||
)
|
||||
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
|
||||
]
|
||||
|
||||
class Annotator:
    """Callable wrapper that turns an image into a ControlNet condition image.

    The detector is selected by ``processor_id``: ``"canny"`` builds a
    ``CannyDetector`` directly, ``"tile"`` uses no detector at all, and the
    remaining identifiers load their detector with
    ``from_pretrained(model_path)``.
    """

    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=512):
        """Create the detector that belongs to ``processor_id``.

        Args:
            processor_id: Which annotator to build.
            model_path: Passed to ``from_pretrained`` for detectors that load
                weights.
            detect_resolution: Stored and forwarded to the detector on every
                call.

        Raises:
            ValueError: If ``processor_id`` is not a supported identifier.
        """
        # Detectors that are loaded through ``from_pretrained``.
        pretrained_detectors = (
            ("depth", MidasDetector),
            ("softedge", HEDdetector),
            ("lineart", LineartDetector),
            ("lineart_anime", LineartAnimeDetector),
            ("openpose", OpenposeDetector),
        )

        if processor_id == "canny":
            detector = CannyDetector()
        else:
            detector_cls = None
            for known_id, candidate in pretrained_detectors:
                if processor_id == known_id:
                    detector_cls = candidate
                    break
            if detector_cls is not None:
                detector = detector_cls.from_pretrained(model_path)
            elif processor_id == "tile":
                # "tile" needs no annotation; __call__ passes the image through.
                detector = None
            else:
                raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor = detector
        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        """Annotate ``image`` and return it at its original size.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and
                accepted by the detector (presumably a PIL image).

        Returns:
            The detector output resized to ``(width, height)``, or ``image``
            itself when no detector is configured.
        """
        width, height = image.size

        # Only the openpose detector receives the body/hand/face switches.
        extra_options = {}
        if self.processor_id == "openpose":
            extra_options = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }

        if self.processor is None:
            return image

        annotated = self.processor(
            image,
            detect_resolution=self.detect_resolution,
            image_resolution=min(width, height),
            **extra_options
        )
        return annotated.resize((width, height))
|
||||
|
||||
Reference in New Issue
Block a user