support dpo training

This commit is contained in:
mi804
2025-09-22 10:14:17 +08:00
parent b0abdaffb4
commit bf7b339efb
7 changed files with 213 additions and 6 deletions

View File

@@ -75,7 +75,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None):
def forward(self, data, inputs=None, **kwargs):
if inputs is None: inputs = self.forward_preprocess(data)
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs)