From ba58f1bc0bd6895dcaf6db0c6d82161fb353f3e7 Mon Sep 17 00:00:00 2001 From: tc2000731 Date: Mon, 18 Nov 2024 18:34:21 +0800 Subject: [PATCH] fix_image_resize --- diffsynth/data/simple_text_image.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/diffsynth/data/simple_text_image.py b/diffsynth/data/simple_text_image.py index 1737098..7a9525e 100644 --- a/diffsynth/data/simple_text_image.py +++ b/diffsynth/data/simple_text_image.py @@ -1,4 +1,4 @@ -import torch, os +import torch, os, torchvision from torchvision import transforms import pandas as pd from PIL import Image @@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset): metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv")) self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]] self.text = metadata["text"].to_list() + self.height = height + self.width = width self.image_processor = transforms.Compose( [ - transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)), transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), @@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset): data_id = (data_id + index) % len(self.path) # For fixed seed. text = self.text[data_id] image = Image.open(self.path[data_id]).convert("RGB") + target_height, target_width = self.height, self.width + width, height = image.size + scale = max(target_width / width, target_height / height) + shape = [round(height*scale),round(width*scale)] + image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR) image = self.image_processor(image) return {"text": text, "image": image}