From 030ebe649a86652adf25a86947cb8a6c3c3fdf85 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 21 Jan 2026 20:22:43 +0800 Subject: [PATCH] update to 2.0.3 --- diffsynth/core/data/unified_dataset.py | 6 +++++- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/diffsynth/core/data/unified_dataset.py b/diffsynth/core/data/unified_dataset.py index 074208c..46fecd7 100644 --- a/diffsynth/core/data/unified_dataset.py +++ b/diffsynth/core/data/unified_dataset.py @@ -10,6 +10,7 @@ class UnifiedDataset(torch.utils.data.Dataset): data_file_keys=tuple(), main_data_operator=lambda x: x, special_operator_map=None, + max_data_items=None, ): self.base_path = base_path self.metadata_path = metadata_path @@ -18,6 +19,7 @@ class UnifiedDataset(torch.utils.data.Dataset): self.main_data_operator = main_data_operator self.cached_data_operator = LoadTorchPickle() self.special_operator_map = {} if special_operator_map is None else special_operator_map + self.max_data_items = max_data_items self.data = [] self.cached_data = [] self.load_from_cache = metadata_path is None @@ -97,7 +99,9 @@ class UnifiedDataset(torch.utils.data.Dataset): return data def __len__(self): - if self.load_from_cache: + if self.max_data_items is not None: + return self.max_data_items + elif self.load_from_cache: return len(self.cached_data) * self.repeat else: return len(self.data) * self.repeat diff --git a/pyproject.toml b/pyproject.toml index 059e21d..de82279 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "diffsynth" -version = "2.0.2" +version = "2.0.3" description = "Enjoy the magic of Diffusion models!" authors = [{name = "ModelScope Team"}] license = {text = "Apache-2.0"}