diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b5c0af0..b171857 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -1,4 +1,4 @@ -import imageio, os, torch, warnings, torchvision, argparse +import imageio, os, torch, warnings, torchvision, argparse, json from peft import LoraConfig, inject_adapter_in_model from PIL import Image import pandas as pd @@ -48,9 +48,14 @@ class ImageDataset(torch.utils.data.Dataset): print("No metadata. Trying to generate it.") metadata = self.generate_metadata(base_path) print(f"{len(metadata)} lines in metadata.") + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata else: metadata = pd.read_csv(metadata_path) - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] def generate_metadata(self, folder): @@ -177,9 +182,14 @@ class VideoDataset(torch.utils.data.Dataset): print("No metadata. Trying to generate it.") metadata = self.generate_metadata(base_path) print(f"{len(metadata)} lines in metadata.") + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata else: metadata = pd.read_csv(metadata_path) - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] def generate_metadata(self, folder):