From 260e32217f649d1dc8a267cc303e42b3a409a129 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 2 Sep 2025 13:14:08 +0800 Subject: [PATCH] unified dataset --- diffsynth/trainers/unified_dataset.py | 334 ++++++++++++++++++++ examples/flux/model_training/train.py | 18 +- examples/qwen_image/model_training/train.py | 18 +- examples/wanvideo/model_training/train.py | 21 +- 4 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 diffsynth/trainers/unified_dataset.py diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py new file mode 100644 index 0000000..fc40eaa --- /dev/null +++ b/diffsynth/trainers/unified_dataset.py @@ -0,0 +1,334 @@ +import torch, torchvision, imageio, os, json, pandas +import imageio.v3 as iio +from PIL import Image + + + +class DataProcessingPipeline: + def __init__(self, operators=None): + self.operators: list[DataProcessingOperator] = [] if operators is None else operators + + def __call__(self, data): + for operator in self.operators: + data = operator(data) + return data + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline(self.operators + pipe.operators) + + + +class DataProcessingOperator: + def __call__(self, data): + raise NotImplementedError("DataProcessingOperator cannot be called directly.") + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline([self]).__rshift__(pipe) + + + +class DataProcessingOperatorRaw(DataProcessingOperator): + def __call__(self, data): + return data + + + +class ToInt(DataProcessingOperator): + def __call__(self, data): + return int(data) + + + +class ToFloat(DataProcessingOperator): + def __call__(self, data): + return float(data) + + + +class ToStr(DataProcessingOperator): + def __init__(self, none_value=""): + self.none_value = none_value + + def __call__(self, data): + if data is None: data = self.none_value + return str(data) + + + +class LoadImage(DataProcessingOperator): + def __init__(self, convert_RGB=True): + self.convert_RGB = convert_RGB + + def __call__(self, data: str): + image = Image.open(data) + if self.convert_RGB: image = image.convert("RGB") + return image + + + +class ImageCropAndResize(DataProcessingOperator): + def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor): + self.height = height + self.width = width + self.max_pixels = max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.height is None or self.width is None: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + + def __call__(self, data: Image.Image): + image = self.crop_and_resize(data, *self.get_height_width(data)) + return image + + + +class ToList(DataProcessingOperator): + def __call__(self, data): + return [data] + + + +class LoadVideo(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, reader): + num_frames = self.num_frames + if int(reader.count_frames()) < num_frames: + num_frames = int(reader.count_frames()) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + reader = imageio.get_reader(data) + num_frames = self.get_num_frames(reader) + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(frame_id) + frame = Image.fromarray(frame) + frame = self.frame_processor(frame) + frames.append(frame) + reader.close() + return frames + + + +class SequencialProcess(DataProcessingOperator): + def __init__(self, operator=lambda x: x): + self.operator = operator + + def __call__(self, data): + return [self.operator(i) for i in data] + + + +class LoadGIF(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, path): + num_frames = self.num_frames + images = iio.imread(path, mode="RGB") + if len(images) < num_frames: + num_frames = len(images) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + num_frames = self.get_num_frames(data) + frames = [] + images = iio.imread(data, mode="RGB") + for img in images: + frame = Image.fromarray(img) + frame = self.frame_processor(frame) + frames.append(frame) + if len(frames) >= num_frames: + break + return frames + + + +class RouteByExtensionName(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data: str): + file_ext_name = data.split(".")[-1].lower() + for ext_names, operator in self.operator_map: + if ext_names is None or file_ext_name in ext_names: + return operator(data) + raise ValueError(f"Unsupported file: {data}") + + + +class RouteByType(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data): + for dtype, operator in self.operator_map: + if dtype is None or isinstance(data, dtype): + return operator(data) + raise ValueError(f"Unsupported data: {data}") + + + +class LoadTorchPickle(DataProcessingOperator): + def __init__(self, map_location="cpu"): + self.map_location = map_location + + def __call__(self, data): + return torch.load(data, map_location=self.map_location) + + + +class ToAbsolutePath(DataProcessingOperator): + def __init__(self, base_path=""): + self.base_path = base_path + + def __call__(self, data): + return os.path.join(self.base_path, data) + + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + special_operator_map=None, + ): + self.base_path = base_path + self.metadata_path = metadata_path + self.repeat = repeat + self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.cached_data_operator = LoadTorchPickle() + self.special_operator_map = {} if special_operator_map is None else special_operator_map + self.data = [] + self.cached_data = [] + self.load_from_cache = metadata_path is None + self.load_metadata(metadata_path) + + @staticmethod + def default_image_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))), + ]) + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF(num_frames, time_division_factor, time_division_remainder) >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), + ]) + + def search_for_cached_data_files(self, path): + for file_name in os.listdir(path): + subpath = os.path.join(path, file_name) + if os.path.isdir(subpath): + self.search_for_cached_data_files(subpath) + elif subpath.endswith(".pth"): + self.cached_data.append(subpath) + + def load_metadata(self, metadata_path): + if metadata_path is None: + print("No metadata_path. Searching for cached data files.") + self.search_for_cached_data_files(self.base_path) + print(f"{len(self.cached_data)} cached data files found.") + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in f: + metadata.append(json.loads(line.strip())) + self.data = metadata + else: + metadata = pandas.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + def __getitem__(self, data_id): + if self.load_from_cache: + data = self.cached_data[data_id % len(self.data)].copy() + data = self.cached_data_operator(data) + else: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + if key in self.special_operator_map: + data[key] = self.special_operator_map[key] + elif key in self.data_file_keys: + data[key] = self.main_data_operator(data[key]) + return data + + def __len__(self): + if self.load_from_cache: + return len(self.cached_data) * self.repeat + else: + return len(self.data) * self.repeat + + def check_data_equal(self, data1, data2): + # Debug only + if len(data1) != len(data2): + return False + for k in data1: + if data1[k] != data2[k]: + return False + return True diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 46eac56..a0db5c4 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -1,8 +1,9 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -106,7 +107,20 @@ class FluxTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = flux_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) model = FluxTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index ee6752d..5faa763 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -2,7 +2,8 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -119,7 +120,20 @@ class QwenImageTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = qwen_image_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) model = QwenImageTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index f2f437e..7df70da 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -1,7 +1,8 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -112,7 +113,23 @@ class WanTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = wan_parser() args = parser.parse_args() - dataset = VideoDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_video_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + num_frames=args.num_frames, + time_division_factor=4, + time_division_remainder=1, + ), + ) model = WanTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths,