mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Merge pull request #269 from modelscope/fix_image_resize
fix_image_resize
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
import torch, os
|
import torch, os, torchvision
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -11,9 +11,10 @@ class TextImageDataset(torch.utils.data.Dataset):
|
|||||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
self.text = metadata["text"].to_list()
|
self.text = metadata["text"].to_list()
|
||||||
|
self.height = height
|
||||||
|
self.width = width
|
||||||
self.image_processor = transforms.Compose(
|
self.image_processor = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
|
||||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
@@ -27,6 +28,11 @@ class TextImageDataset(torch.utils.data.Dataset):
|
|||||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||||
text = self.text[data_id]
|
text = self.text[data_id]
|
||||||
image = Image.open(self.path[data_id]).convert("RGB")
|
image = Image.open(self.path[data_id]).convert("RGB")
|
||||||
|
target_height, target_width = self.height, self.width
|
||||||
|
width, height = image.size
|
||||||
|
scale = max(target_width / width, target_height / height)
|
||||||
|
shape = [round(height*scale),round(width*scale)]
|
||||||
|
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
||||||
image = self.image_processor(image)
|
image = self.image_processor(image)
|
||||||
return {"text": text, "image": image}
|
return {"text": text, "image": image}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user