# NOTE: the star import stays first so the explicit imports below take
# precedence over any same-named symbols re-exported by `.operators`.
from .operators import *

import json
import os

import pandas
import torch


class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset backed either by a metadata file or by cached ``.pth`` files.

    Two modes are supported, selected by ``metadata_path``:

    * ``metadata_path`` is given: every sample is one record (a dict) read
      from a ``.json``, ``.jsonl`` or CSV file. On access, the values stored
      under ``data_file_keys`` are passed through an operator.
    * ``metadata_path`` is ``None``: ``base_path`` is searched recursively for
      ``.pth`` files and every sample is loaded with ``LoadTorchPickle()``.

    Indices wrap around the underlying data, and ``__len__`` reports the
    underlying size multiplied by ``repeat``.
    """

    def __init__(
        self,
        base_path=None,
        metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        """Store the configuration and load the metadata / cache index.

        Args:
            base_path: Directory searched recursively for ``.pth`` files when
                ``metadata_path`` is ``None``. Not used by this class
                otherwise.
            metadata_path: Path to a ``.json``, ``.jsonl`` or CSV metadata
                file (any other extension is read as CSV), or ``None`` to
                load cached data instead.
            repeat: Factor by which the reported dataset length is multiplied.
            data_file_keys: Record keys whose values are transformed by an
                operator in ``__getitem__``.
            main_data_operator: Callable applied to the value of every key in
                ``data_file_keys`` that has no entry in
                ``special_operator_map``.
            special_operator_map: Optional mapping from key to a callable that
                replaces ``main_data_operator`` for that key.
        """
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
    ):
        """Build the default operator for image entries.

        The returned ``RouteByType`` handles two input types:

        * ``str``: ``ToAbsolutePath(base_path) >> LoadImage() >>
          ImageCropAndResize(...)``.
        * ``list``: the same pipeline wrapped in ``SequencialProcess``.

        ``height``, ``width``, ``max_pixels`` and the two division factors are
        forwarded unchanged to ``ImageCropAndResize``.
        """
        def build_pipeline():
            # A fresh pipeline per route, so no operator instance is shared.
            return (
                ToAbsolutePath(base_path)
                >> LoadImage()
                >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)
            )

        return RouteByType(operator_map=[
            (str, build_pipeline()),
            (list, SequencialProcess(build_pipeline())),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080,
        height=None,
        width=None,
        height_division_factor=16,
        width_division_factor=16,
        num_frames=81,
        time_division_factor=4,
        time_division_remainder=1,
    ):
        """Build the default operator for video entries.

        Only ``str`` inputs are routed. The path goes through
        ``ToAbsolutePath(base_path)`` and is then dispatched on its extension:

        * ``jpg`` / ``jpeg`` / ``png`` / ``webp``: ``LoadImage() >>
          ImageCropAndResize(...) >> ToList()``.
        * ``gif``: ``LoadGIF`` with ``ImageCropAndResize`` as frame processor.
        * ``mp4`` / ``avi`` / ``mov`` / ``wmv`` / ``mkv`` / ``flv`` / ``webm``:
          ``LoadVideo`` with ``ImageCropAndResize`` as frame processor.

        ``num_frames``, ``time_division_factor`` and
        ``time_division_remainder`` are forwarded unchanged to ``LoadGIF`` and
        ``LoadVideo``; the remaining size arguments are forwarded to
        ``ImageCropAndResize``.
        """
        def build_resizer():
            # A fresh resize operator per use, so no instance is shared.
            return ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)

        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> build_resizer() >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=build_resizer(),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=build_resizer(),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        """Recursively append every ``.pth`` file under ``path`` to ``self.cached_data``."""
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        """Populate ``self.data`` (or ``self.cached_data``) from disk.

        Args:
            metadata_path: ``None`` to index cached ``.pth`` files under
                ``self.base_path``; otherwise a path ending in ``.json``
                (parsed as a single JSON document), ``.jsonl`` (one JSON
                document per non-blank line), or anything else (read with
                ``pandas.read_csv``, one dict per row).
        """
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                self.data = json.load(f)
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, "r") as f:
                for line in f:
                    line = line.strip()
                    # Blank lines (e.g. a trailing newline) are not records;
                    # json.loads("") would raise on them.
                    if line:
                        metadata.append(json.loads(line))
            self.data = metadata
        else:
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        """Return the sample at ``data_id``; the index wraps around the data.

        In cache mode the selected ``.pth`` path is loaded with
        ``self.cached_data_operator``. Otherwise a shallow copy of the record
        is returned in which each key listed in ``self.data_file_keys`` (and
        present in the record) has been replaced by the output of its
        operator: the entry in ``self.special_operator_map`` if there is one,
        else ``self.main_data_operator``.
        """
        if self.load_from_cache:
            cached_path = self.cached_data[data_id % len(self.cached_data)]
            return self.cached_data_operator(cached_path)

        # Shallow copy so the stored record keeps its raw (unprocessed) values.
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key not in data:
                continue
            if key in self.special_operator_map:
                operator = self.special_operator_map[key]
            else:
                operator = self.main_data_operator
            data[key] = operator(data[key])
        return data

    def __len__(self):
        """Return the number of underlying samples multiplied by ``self.repeat``."""
        source = self.cached_data if self.load_from_cache else self.data
        return len(source) * self.repeat

    def check_data_equal(self, data1, data2):
        """Return True if both dicts have the same keys and ``!=``-equal values.

        Debug only. A key present in ``data1`` but missing from ``data2``
        makes the dicts unequal instead of raising ``KeyError``.

        NOTE(review): values are compared with ``!=`` and truth-tested, so
        values whose comparison does not yield a plain bool (presumably e.g.
        multi-element tensors) are not supported here.
        """
        if len(data1) != len(data2):
            return False
        for k in data1:
            if k not in data2 or data1[k] != data2[k]:
                return False
        return True