ignore metadata

This commit is contained in:
Artiprocher
2025-03-19 11:36:07 +08:00
parent ce848a3d1a
commit 9d8130b48d

View File

@@ -167,7 +167,8 @@ class LightningModelForDataProcess(pl.LightningModule):
class TensorDataset(torch.utils.data.Dataset): class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None): def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
if os.path.exists(metadata_path):
metadata = pd.read_csv(metadata_path) metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.") print(len(self.path), "videos in metadata.")
@@ -181,6 +182,9 @@ class TensorDataset(torch.utils.data.Dataset):
if os.path.exists(path + ".tensors.pth"): if os.path.exists(path + ".tensors.pth"):
cached_path.append(path + ".tensors.pth") cached_path.append(path + ".tensors.pth")
self.path = cached_path self.path = cached_path
else:
print("Cannot find metadata.csv. Trying to search for tensor files.")
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
print(len(self.path), "tensors cached in metadata.") print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0 assert len(self.path) > 0