import torch
import numpy as np
from PIL import Image


class BasePipeline(torch.nn.Module):
    """Base module holding image <-> tensor conversion helpers.

    Images are mapped between 8-bit pixel values (0..255) and float tensors
    in the range [-1, 1]. Tensors use a channels-first layout; arrays handed
    to Pillow use a channels-last layout.
    """

    def __init__(self, device: str = "cuda", torch_dtype: torch.dtype = torch.float16):
        """Store the target device and dtype.

        Args:
            device: Device identifier kept on the instance as ``self.device``.
            torch_dtype: Dtype kept on the instance as ``self.torch_dtype``.

        Both values are only stored here; none of the methods in this class
        read them.
        """
        super().__init__()
        self.device = device
        self.torch_dtype = torch_dtype

    @staticmethod
    def _to_uint8_pixels(tensor: torch.Tensor) -> np.ndarray:
        """Map a tensor with values in [-1, 1] to a uint8 array in 0..255."""
        # detach(): .numpy() refuses tensors that require grad.
        # float(): NumPy has no bfloat16, so reduced-precision tensors are
        # upcast to float32 before leaving torch.
        array = tensor.detach().cpu().float().numpy()
        # Out-of-range values are clipped; the fractional part is truncated
        # by the uint8 cast (no rounding).
        return ((array / 2 + 0.5).clip(0, 1) * 255).astype("uint8")

    def preprocess_image(self, image) -> torch.Tensor:
        """Convert one image into a batched, channels-first float tensor.

        Args:
            image: Anything ``np.array`` turns into a 3-D array. Assumed to be
                laid out as (height, width, channels) with values in 0..255,
                e.g. an RGB PIL image -- the code does not validate this.

        Returns:
            A tensor of shape (1, channels, height, width) where 0 maps to -1
            and 255 maps to 1.
        """
        pixels = np.array(image, dtype=np.float32) * (2 / 255) - 1
        # Move the last axis to the front, then add a leading batch axis.
        return torch.Tensor(pixels).permute(2, 0, 1).unsqueeze(0)

    def preprocess_images(self, images) -> list:
        """Apply :meth:`preprocess_image` to every item and return a list."""
        return [self.preprocess_image(image) for image in images]

    def vae_output_to_image(self, vae_output: torch.Tensor) -> Image.Image:
        """Convert the first item of a decoded batch into a PIL image.

        Args:
            vae_output: Tensor whose first element along axis 0 has shape
                (channels, height, width) and values nominally in [-1, 1].

        Returns:
            The PIL image built from ``vae_output[0]``.
        """
        channels_last = vae_output[0].permute(1, 2, 0)
        return Image.fromarray(self._to_uint8_pixels(channels_last))

    def vae_output_to_video(self, vae_output: torch.Tensor) -> list:
        """Convert a stack of decoded frames into a list of PIL images.

        Args:
            vae_output: Tensor of shape (frames, channels, height, width)
                with values nominally in [-1, 1]. A 3-D
                (channels, height, width) tensor is treated as one frame.
                NOTE(review): the frames-first layout mirrors the batch-first
                layout used by ``vae_output_to_image``; confirm that callers
                do not pass channels-first video tensors.

        Returns:
            One PIL image per frame, in order.

        Raises:
            RuntimeError: If the tensor is neither 3-D nor 4-D (raised by
                ``permute``).
        """
        if vae_output.dim() == 3:
            vae_output = vae_output.unsqueeze(0)
        # Keep the frame axis first and move channels last for Pillow. The
        # previous 3-axis permute could not handle a stack of frames at all.
        frames = self._to_uint8_pixels(vae_output.permute(0, 2, 3, 1))
        return [Image.fromarray(frame) for frame in frames]