mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 10:48:11 +00:00
ignore metadata
This commit is contained in:
@@ -167,7 +167,8 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
|
|
||||||
|
|
||||||
class TensorDataset(torch.utils.data.Dataset):
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None):
|
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
metadata = pd.read_csv(metadata_path)
|
metadata = pd.read_csv(metadata_path)
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
print(len(self.path), "videos in metadata.")
|
print(len(self.path), "videos in metadata.")
|
||||||
@@ -181,6 +182,9 @@ class TensorDataset(torch.utils.data.Dataset):
|
|||||||
if os.path.exists(path + ".tensors.pth"):
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
cached_path.append(path + ".tensors.pth")
|
cached_path.append(path + ".tensors.pth")
|
||||||
self.path = cached_path
|
self.path = cached_path
|
||||||
|
else:
|
||||||
|
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||||
|
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||||
print(len(self.path), "tensors cached in metadata.")
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
assert len(self.path) > 0
|
assert len(self.path) > 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user