diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py index 00b90d1..94aff4c 100644 --- a/diffsynth/extensions/ESRGAN/__init__.py +++ b/diffsynth/extensions/ESRGAN/__init__.py @@ -107,6 +107,12 @@ class ESRGAN(torch.nn.Module): @torch.no_grad() def upscale(self, images, batch_size=4, progress_bar=lambda x:x): + if not isinstance(images, list): + images = [images] + is_single_image = True + else: + is_single_image = False + # Preprocess input_tensor = self.process_images(images) @@ -126,4 +132,6 @@ class ESRGAN(torch.nn.Module): # To images output_images = self.decode_images(output_tensor) + if is_single_image: + output_images = output_images[0] return output_images