skip bad files

This commit is contained in:
Artiprocher
2025-03-19 14:49:18 +08:00
parent 9d8130b48d
commit 6cd032e846

View File

@@ -193,11 +193,15 @@ class TensorDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
while True:
try:
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
except:
continue
def __len__(self):