mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
import torch
|
|
import numpy as np
|
|
from .processors import Processor_id
|
|
|
|
|
|
class ControlNetConfigUnit:
    """Configuration record describing a single ControlNet entry.

    The constructor arguments are stored unchanged on the instance as
    ``processor_id``, ``model_path``, ``scale`` and ``skip_processor``.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
        # Pure value holder: nothing is validated, converted or loaded here.
        self.processor_id, self.model_path = processor_id, model_path
        self.scale, self.skip_processor = scale, skip_processor
|
|
|
|
|
|
class ControlNetUnit:
    """Group a processor, a model and the scale applied to that model's output.

    The three constructor arguments are kept as-is on the instance under the
    attribute names ``processor``, ``model`` and ``scale``.
    """

    def __init__(self, processor, model, scale=1.0):
        # Assigned in one statement; no validation or copying takes place.
        self.processor, self.model, self.scale = processor, model, scale
|
|
|
|
|
|
class MultiControlNetManager:
    """Drive several ControlNet units together and sum their scaled outputs.

    The units passed to the constructor are split into three parallel lists,
    ``processors``, ``models`` and ``scales``, which share the same ordering.
    """

    def __init__(self, controlnet_units=None):
        """Build the parallel processor/model/scale lists.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes. ``None`` (the default)
                is treated as "no units".
        """
        # Materialise the input exactly once. It is read three times below, so
        # a generator or other one-shot iterator would otherwise populate
        # ``processors`` and leave ``models`` / ``scales`` empty. Using ``None``
        # as the default also avoids a shared mutable default argument.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def cpu(self):
        """Call ``cpu()`` on every model; processors are not touched."""
        for model in self.models:
            model.cpu()

    def to(self, device):
        """Call ``to(device)`` on every model and then on every processor."""
        for model in self.models:
            model.to(device)
        for processor in self.processors:
            processor.to(device)

    def process_image(self, image, processor_id=None):
        """Run the processor(s) on ``image`` and stack the results as a tensor.

        Args:
            image: Input handed unchanged to each selected processor.
            processor_id: Index into ``self.processors``. When ``None`` every
                processor is applied, in order.

        Returns:
            A tensor concatenated along dim 0 with one entry per processed
            image. Each processor output is converted to a float32 array,
            divided by 255, and its axes are reordered with ``permute(2, 0, 1)``
            (assumes each processor returns a 3-D, channels-last image --
            TODO confirm against the processors).
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32, **kwargs
    ):
        """Evaluate every model on its conditioning and sum the scaled results.

        Args:
            sample, timestep, encoder_hidden_states: Forwarded positionally to
                each model.
            conditionings: Iterable paired element-wise with the units; the
                pairing uses ``zip`` and therefore stops at the shortest input.
            tiled, tile_size, tile_stride: Forwarded to each model by keyword.
            **kwargs: Extra keyword arguments forwarded to each model.

        Returns:
            A list whose entries are the element-wise sums of each model's
            output entries multiplied by that unit's scale, or ``None`` when
            there is nothing to iterate over.
        """
        res_stack = None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning, **kwargs,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                processor_id=processor.processor_id
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                # First unit: its scaled output seeds the accumulator.
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack
|
|
|
|
|
|
class FluxMultiControlNetManager(MultiControlNetManager):
    """Variant of the multi-ControlNet manager whose models return two stacks.

    Differences visible in this class: ``process_image`` returns the raw
    processor outputs as a list, and ``__call__`` accumulates a pair of
    result lists (``res_stack`` and ``single_res_stack``) per model.
    """

    def __init__(self, controlnet_units=None):
        """Initialise the parent with a concrete list of units.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes. ``None`` (the default)
                is treated as "no units".
        """
        # Hand the parent a real list: this avoids a shared mutable default
        # argument and keeps ``None`` and one-shot iterators (e.g. generators)
        # working even though the parent constructor iterates over its input.
        units = [] if controlnet_units is None else list(controlnet_units)
        super().__init__(controlnet_units=units)

    def process_image(self, image, processor_id=None):
        """Run the processor(s) on ``image`` and return their outputs as a list.

        Args:
            image: Input handed unchanged to each selected processor.
            processor_id: Index into ``self.processors``. When ``None`` every
                processor is applied, in order.

        Returns:
            A list with one processor output per selected processor; no
            tensor conversion is performed here.
        """
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        return processed_image

    def __call__(self, conditionings, **kwargs):
        """Evaluate every model on its conditioning and sum the scaled results.

        Args:
            conditionings: Iterable paired element-wise with the units; the
                pairing uses ``zip`` and therefore stops at the shortest input.
            **kwargs: Extra keyword arguments forwarded to each model next to
                ``controlnet_conditioning`` and ``processor_id``.

        Returns:
            A ``(res_stack, single_res_stack)`` tuple. Each element is a list
            of element-wise sums of the models' outputs multiplied by the
            unit's scale, or ``(None, None)`` when there is nothing to
            iterate over.
        """
        res_stack, single_res_stack = None, None
        for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
            res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
            res_stack_ = [res * scale for res in res_stack_]
            single_res_stack_ = [res * scale for res in single_res_stack_]
            if res_stack is None:
                # First unit: its scaled outputs seed both accumulators.
                res_stack = res_stack_
                single_res_stack = single_res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
                single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
        return res_stack, single_res_stack
|