From 6cd032e846398adf8ffd502d7b29ce4966d29602 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 19 Mar 2025 14:49:18 +0800 Subject: [PATCH] skip bad files --- examples/wanvideo/train_wan_t2v.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 2057730..e7b1495 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -193,11 +193,15 @@ class TensorDataset(torch.utils.data.Dataset): def __getitem__(self, index): - data_id = torch.randint(0, len(self.path), (1,))[0] - data_id = (data_id + index) % len(self.path) # For fixed seed. - path = self.path[data_id] - data = torch.load(path, weights_only=True, map_location="cpu") - return data + while True: + try: + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + path = self.path[data_id] + data = torch.load(path, weights_only=True, map_location="cpu") + return data + except: + continue def __len__(self):