qwen-image splited training

This commit is contained in:
Artiprocher
2025-09-02 16:44:14 +08:00
parent 260e32217f
commit b6da77e468
7 changed files with 221 additions and 14 deletions

View File

@@ -214,7 +214,7 @@ class LoadTorchPickle(DataProcessingOperator):
self.map_location = map_location
def __call__(self, data):
return torch.load(data, map_location=self.map_location)
return torch.load(data, map_location=self.map_location, weights_only=False)
@@ -306,7 +306,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
def __getitem__(self, data_id):
if self.load_from_cache:
data = self.cached_data[data_id % len(self.data)].copy()
data = self.cached_data[data_id % len(self.cached_data)]
data = self.cached_data_operator(data)
else:
data = self.data[data_id % len(self.data)].copy()