mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support dpo training
This commit is contained in:
@@ -82,7 +82,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
return {**inputs_shared, **inputs_posi}
|
||||
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
def forward(self, data, inputs=None, **kwargs):
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
|
||||
Reference in New Issue
Block a user