mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 09:28:12 +00:00
Merge pull request #87 from Lupino/main
pass device to processors Annotator
This commit is contained in:
@@ -12,24 +12,24 @@ Processor_id: TypeAlias = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
class Annotator:
|
class Annotator:
|
||||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
|
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
|
||||||
if processor_id == "canny":
|
if processor_id == "canny":
|
||||||
self.processor = CannyDetector()
|
self.processor = CannyDetector()
|
||||||
elif processor_id == "depth":
|
elif processor_id == "depth":
|
||||||
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
|
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "softedge":
|
elif processor_id == "softedge":
|
||||||
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
|
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "lineart":
|
elif processor_id == "lineart":
|
||||||
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
|
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "lineart_anime":
|
elif processor_id == "lineart_anime":
|
||||||
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
|
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "openpose":
|
elif processor_id == "openpose":
|
||||||
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
|
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
||||||
elif processor_id == "tile":
|
elif processor_id == "tile":
|
||||||
self.processor = None
|
self.processor = None
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
||||||
|
|
||||||
self.processor_id = processor_id
|
self.processor_id = processor_id
|
||||||
self.detect_resolution = detect_resolution
|
self.detect_resolution = detect_resolution
|
||||||
|
|
||||||
|
|||||||
@@ -39,14 +39,14 @@ class SDImagePipeline(torch.nn.Module):
|
|||||||
controlnet_units = []
|
controlnet_units = []
|
||||||
for config in controlnet_config_units:
|
for config in controlnet_config_units:
|
||||||
controlnet_unit = ControlNetUnit(
|
controlnet_unit = ControlNetUnit(
|
||||||
Annotator(config.processor_id),
|
Annotator(config.processor_id, device=self.device),
|
||||||
model_manager.get_model_with_model_path(config.model_path),
|
model_manager.get_model_with_model_path(config.model_path),
|
||||||
config.scale
|
config.scale
|
||||||
)
|
)
|
||||||
controlnet_units.append(controlnet_unit)
|
controlnet_units.append(controlnet_unit)
|
||||||
self.controlnet = MultiControlNetManager(controlnet_units)
|
self.controlnet = MultiControlNetManager(controlnet_units)
|
||||||
|
|
||||||
|
|
||||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||||
if "ipadapter" in model_manager.model:
|
if "ipadapter" in model_manager.model:
|
||||||
self.ipadapter = model_manager.ipadapter
|
self.ipadapter = model_manager.ipadapter
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
controlnet_units = []
|
controlnet_units = []
|
||||||
for config in controlnet_config_units:
|
for config in controlnet_config_units:
|
||||||
controlnet_unit = ControlNetUnit(
|
controlnet_unit = ControlNetUnit(
|
||||||
Annotator(config.processor_id),
|
Annotator(config.processor_id, device=self.device),
|
||||||
model_manager.get_model_with_model_path(config.model_path),
|
model_manager.get_model_with_model_path(config.model_path),
|
||||||
config.scale
|
config.scale
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user