mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
ignore metadata
This commit is contained in:
@@ -167,20 +167,24 @@ class LightningModelForDataProcess(pl.LightningModule):
|
|||||||
|
|
||||||
|
|
||||||
class TensorDataset(torch.utils.data.Dataset):
|
class TensorDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, base_path, metadata_path, steps_per_epoch, redirected_tensor_path=None):
|
def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
|
||||||
metadata = pd.read_csv(metadata_path)
|
if os.path.exists(metadata_path):
|
||||||
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
metadata = pd.read_csv(metadata_path)
|
||||||
print(len(self.path), "videos in metadata.")
|
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||||
if redirected_tensor_path is None:
|
print(len(self.path), "videos in metadata.")
|
||||||
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
if redirected_tensor_path is None:
|
||||||
|
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
|
||||||
|
else:
|
||||||
|
cached_path = []
|
||||||
|
for path in self.path:
|
||||||
|
path = path.replace("/", "_").replace("\\", "_")
|
||||||
|
path = os.path.join(redirected_tensor_path, path)
|
||||||
|
if os.path.exists(path + ".tensors.pth"):
|
||||||
|
cached_path.append(path + ".tensors.pth")
|
||||||
|
self.path = cached_path
|
||||||
else:
|
else:
|
||||||
cached_path = []
|
print("Cannot find metadata.csv. Trying to search for tensor files.")
|
||||||
for path in self.path:
|
self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
|
||||||
path = path.replace("/", "_").replace("\\", "_")
|
|
||||||
path = os.path.join(redirected_tensor_path, path)
|
|
||||||
if os.path.exists(path + ".tensors.pth"):
|
|
||||||
cached_path.append(path + ".tensors.pth")
|
|
||||||
self.path = cached_path
|
|
||||||
print(len(self.path), "tensors cached in metadata.")
|
print(len(self.path), "tensors cached in metadata.")
|
||||||
assert len(self.path) > 0
|
assert len(self.path) > 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user