mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 18:58:11 +00:00
ignore metadata
This commit is contained in:
@@ -167,7 +167,8 @@ class LightningModelForDataProcess(pl.LightningModule):
|
||||
|
||||
|
||||
class TensorDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None):
|
||||
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
@@ -181,6 +182,9 @@ class TensorDataset(torch.utils.data.Dataset):
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
else:
|
||||
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||
print(len(self.path), "tensors cached in metadata.")
|
||||
assert len(self.path) > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user