qwen-image splited training

This commit is contained in:
Artiprocher
2025-09-02 16:44:14 +08:00
parent 260e32217f
commit b6da77e468
7 changed files with 221 additions and 14 deletions

View File

@@ -111,6 +111,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
def forward(self, data, inputs=None):
if inputs is None: inputs = self.forward_preprocess(data)
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs)
return loss