Merge pull request #269 from modelscope/fix_image_resize

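fix_image_resize: resize each image by a per-image scale factor (preserving its aspect ratio) so that it covers the target height and width before cropping, replacing the fixed Resize(max(height, width)) transform.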
This commit is contained in:
Zhongjie Duan
2024-11-18 19:24:57 +08:00
committed by GitHub

View File

@@ -1,4 +1,4 @@
import torch, os import torch, os, torchvision
from torchvision import transforms from torchvision import transforms
import pandas as pd import pandas as pd
from PIL import Image from PIL import Image
@@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset):
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv")) metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]] self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list() self.text = metadata["text"].to_list()
self.height = height
self.width = width
self.image_processor = transforms.Compose( self.image_processor = transforms.Compose(
[ [
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)), transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(), transforms.ToTensor(),
@@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset):
data_id = (data_id + index) % len(self.path) # For fixed seed. data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id] text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB") image = Image.open(self.path[data_id]).convert("RGB")
target_height, target_width = self.height, self.width
width, height = image.size
scale = max(target_width / width, target_height / height)
shape = [round(height*scale),round(width*scale)]
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
image = self.image_processor(image) image = self.image_processor(image)
return {"text": text, "image": image} return {"text": text, "image": image}