support dpo training

This commit is contained in:
mi804
2025-09-22 10:14:17 +08:00
parent b0abdaffb4
commit bf7b339efb
7 changed files with 213 additions and 6 deletions

View File

@@ -83,7 +83,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None, return_inputs=False):
def forward(self, data, inputs=None, return_inputs=False, **kwargs):
# Inputs
if inputs is None:
inputs = self.forward_preprocess(data)