From 9d8130b48dfcb2dad0c0b34af45caefa435102b9 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 19 Mar 2025 11:36:07 +0800 Subject: [PATCH] ignore metadata --- examples/wanvideo/train_wan_t2v.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 8cab520..2057730 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -167,20 +167,24 @@ class LightningModelForDataProcess(pl.LightningModule): class TensorDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - print(len(self.path), "videos in metadata.") - if redirected_tensor_path is None: - self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None): + if os.path.exists(metadata_path): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + print(len(self.path), "videos in metadata.") + if redirected_tensor_path is None: + self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] + else: + cached_path = [] + for path in self.path: + path = path.replace("/", "_").replace("\\", "_") + path = os.path.join(redirected_tensor_path, path) + if os.path.exists(path + ".tensors.pth"): + cached_path.append(path + ".tensors.pth") + self.path = cached_path else: - cached_path = [] - for path in self.path: - path = path.replace("/", "_").replace("\\", "_") - path = os.path.join(redirected_tensor_path, path) - if os.path.exists(path + ".tensors.pth"): - cached_path.append(path + ".tensors.pth") - self.path = cached_path + print("Cannot find metadata.csv. Trying to search for tensor files.") + self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")] print(len(self.path), "tensors cached in metadata.") assert len(self.path) > 0