update to 2.0.3

This commit is contained in:
Artiprocher
2026-01-21 20:22:43 +08:00
parent 37fbb3248a
commit 030ebe649a
2 changed files with 6 additions and 2 deletions

View File

@@ -10,6 +10,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
data_file_keys=tuple(), data_file_keys=tuple(),
main_data_operator=lambda x: x, main_data_operator=lambda x: x,
special_operator_map=None, special_operator_map=None,
max_data_items=None,
): ):
self.base_path = base_path self.base_path = base_path
self.metadata_path = metadata_path self.metadata_path = metadata_path
@@ -18,6 +19,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
self.main_data_operator = main_data_operator self.main_data_operator = main_data_operator
self.cached_data_operator = LoadTorchPickle() self.cached_data_operator = LoadTorchPickle()
self.special_operator_map = {} if special_operator_map is None else special_operator_map self.special_operator_map = {} if special_operator_map is None else special_operator_map
self.max_data_items = max_data_items
self.data = [] self.data = []
self.cached_data = [] self.cached_data = []
self.load_from_cache = metadata_path is None self.load_from_cache = metadata_path is None
@@ -97,7 +99,9 @@ class UnifiedDataset(torch.utils.data.Dataset):
return data return data
def __len__(self): def __len__(self):
if self.load_from_cache: if self.max_data_items is not None:
return self.max_data_items
elif self.load_from_cache:
return len(self.cached_data) * self.repeat return len(self.cached_data) * self.repeat
else: else:
return len(self.data) * self.repeat return len(self.data) * self.repeat

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "diffsynth" name = "diffsynth"
version = "2.0.2" version = "2.0.3"
description = "Enjoy the magic of Diffusion models!" description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}] authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"} license = {text = "Apache-2.0"}