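fix_image_resize: resize each image per-sample so that it just covers the target (height, width) while keeping its aspect ratio, then crop — replacing the fixed `transforms.Resize(max(height, width))` step in the transform pipeline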

This commit is contained in:
tc2000731
2024-11-18 18:34:21 +08:00
parent 095e8a3de8
commit ba58f1bc0b

View File

@@ -1,4 +1,4 @@
import torch, os import torch, os, torchvision
from torchvision import transforms from torchvision import transforms
import pandas as pd import pandas as pd
from PIL import Image from PIL import Image
@@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset):
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv")) metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]] self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list() self.text = metadata["text"].to_list()
self.height = height
self.width = width
self.image_processor = transforms.Compose( self.image_processor = transforms.Compose(
[ [
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)), transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(), transforms.ToTensor(),
@@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset):
data_id = (data_id + index) % len(self.path) # For fixed seed. data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id] text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB") image = Image.open(self.path[data_id]).convert("RGB")
target_height, target_width = self.height, self.width
width, height = image.size
scale = max(target_width / width, target_height / height)
shape = [round(height*scale),round(width*scale)]
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
image = self.image_processor(image) image = self.image_processor(image)
return {"text": text, "image": image} return {"text": text, "image": image}