skip bad files

This commit is contained in:
Artiprocher
2025-03-19 14:49:18 +08:00
parent 9d8130b48d
commit 6cd032e846

View File

@@ -193,11 +193,15 @@ class TensorDataset(torch.utils.data.Dataset):
def __getitem__(self, index): def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0] while True:
data_id = (data_id + index) % len(self.path) # For fixed seed. try:
path = self.path[data_id] data_id = torch.randint(0, len(self.path), (1,))[0]
data = torch.load(path, weights_only=True, map_location="cpu") data_id = (data_id + index) % len(self.path) # For fixed seed.
return data path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except:
continue
def __len__(self): def __len__(self):