Merge pull request #87 from Lupino/main

pass device to processors Annotator
This commit is contained in:
Artiprocher
2024-07-04 10:34:40 +08:00
committed by GitHub
3 changed files with 10 additions and 10 deletions

View File

@@ -12,24 +12,24 @@ Processor_id: TypeAlias = Literal[
 ]
 class Annotator:
-    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
+    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
         if processor_id == "canny":
             self.processor = CannyDetector()
         elif processor_id == "depth":
-            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
+            self.processor = MidasDetector.from_pretrained(model_path).to(device)
         elif processor_id == "softedge":
-            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
+            self.processor = HEDdetector.from_pretrained(model_path).to(device)
         elif processor_id == "lineart":
-            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
+            self.processor = LineartDetector.from_pretrained(model_path).to(device)
         elif processor_id == "lineart_anime":
-            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
+            self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
         elif processor_id == "openpose":
-            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
+            self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
         elif processor_id == "tile":
             self.processor = None
         else:
             raise ValueError(f"Unsupported processor_id: {processor_id}")
         self.processor_id = processor_id
         self.detect_resolution = detect_resolution

View File

@@ -39,14 +39,14 @@ class SDImagePipeline(torch.nn.Module):
         controlnet_units = []
         for config in controlnet_config_units:
             controlnet_unit = ControlNetUnit(
-                Annotator(config.processor_id),
+                Annotator(config.processor_id, device=self.device),
                 model_manager.get_model_with_model_path(config.model_path),
                 config.scale
             )
             controlnet_units.append(controlnet_unit)
         self.controlnet = MultiControlNetManager(controlnet_units)
     def fetch_ipadapter(self, model_manager: ModelManager):
         if "ipadapter" in model_manager.model:
             self.ipadapter = model_manager.ipadapter

View File

@@ -89,7 +89,7 @@ class SDVideoPipeline(torch.nn.Module):
         controlnet_units = []
         for config in controlnet_config_units:
             controlnet_unit = ControlNetUnit(
-                Annotator(config.processor_id),
+                Annotator(config.processor_id, device=self.device),
                 model_manager.get_model_with_model_path(config.model_path),
                 config.scale
             )