mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
ignore metadata
This commit is contained in:
@@ -167,20 +167,24 @@ class LightningModelForDataProcess(pl.LightningModule):
|
||||
|
||||
|
||||
class TensorDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
if redirected_tensor_path is None:
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = pd.read_csv(metadata_path)
|
||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
print(len(self.path), "videos in metadata.")
|
||||
if redirected_tensor_path is None:
|
||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||
else:
|
||||
cached_path = []
|
||||
for path in self.path:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(redirected_tensor_path, path)
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
else:
|
||||
cached_path = []
|
||||
for path in self.path:
|
||||
path = path.replace("/", "_").replace("\\", "_")
|
||||
path = os.path.join(redirected_tensor_path, path)
|
||||
if os.path.exists(path + ".tensors.pth"):
|
||||
cached_path.append(path + ".tensors.pth")
|
||||
self.path = cached_path
|
||||
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||
print(len(self.path), "tensors cached in metadata.")
|
||||
assert len(self.path) > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user