diff --git a/diffsynth/controlnets/processors.py b/diffsynth/controlnets/processors.py index a378842..1d23c73 100644 --- a/diffsynth/controlnets/processors.py +++ b/diffsynth/controlnets/processors.py @@ -12,24 +12,24 @@ Processor_id: TypeAlias = Literal[ ] class Annotator: - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None): + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): if processor_id == "canny": self.processor = CannyDetector() elif processor_id == "depth": - self.processor = MidasDetector.from_pretrained(model_path).to("cuda") + self.processor = MidasDetector.from_pretrained(model_path).to(device) elif processor_id == "softedge": - self.processor = HEDdetector.from_pretrained(model_path).to("cuda") + self.processor = HEDdetector.from_pretrained(model_path).to(device) elif processor_id == "lineart": - self.processor = LineartDetector.from_pretrained(model_path).to("cuda") + self.processor = LineartDetector.from_pretrained(model_path).to(device) elif processor_id == "lineart_anime": - self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) elif processor_id == "openpose": - self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") + self.processor = OpenposeDetector.from_pretrained(model_path).to(device) elif processor_id == "tile": self.processor = None else: raise ValueError(f"Unsupported processor_id: {processor_id}") - + self.processor_id = processor_id self.detect_resolution = detect_resolution diff --git a/diffsynth/pipelines/stable_diffusion.py b/diffsynth/pipelines/stable_diffusion.py index 4f71f89..d3184bf 100644 --- a/diffsynth/pipelines/stable_diffusion.py +++ b/diffsynth/pipelines/stable_diffusion.py @@ -39,14 +39,14 @@ class SDImagePipeline(torch.nn.Module): controlnet_units = [] for config in controlnet_config_units: controlnet_unit = ControlNetUnit( - Annotator(config.processor_id), + Annotator(config.processor_id, device=self.device), model_manager.get_model_with_model_path(config.model_path), config.scale ) controlnet_units.append(controlnet_unit) self.controlnet = MultiControlNetManager(controlnet_units) - + def fetch_ipadapter(self, model_manager: ModelManager): if "ipadapter" in model_manager.model: self.ipadapter = model_manager.ipadapter diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index b204cad..7894d23 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -89,7 +89,7 @@ class SDVideoPipeline(torch.nn.Module): controlnet_units = [] for config in controlnet_config_units: controlnet_unit = ControlNetUnit( - Annotator(config.processor_id), + Annotator(config.processor_id, device=self.device), model_manager.get_model_with_model_path(config.model_path), config.scale )