mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
support dpo
This commit is contained in:
@@ -2,7 +2,7 @@ import torch, os, json
|
||||
from diffsynth import load_state_dict
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
|
||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
@@ -84,24 +84,29 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
|
||||
|
||||
def forward(self, data, inputs=None, return_inputs=False):
    """Produce the training output for one sample, selected by ``self.task``.

    Args:
        data: The raw sample. It is handed to ``DPOLoss().loss`` for the
            "dpo" task and to ``self.forward_preprocess`` otherwise.
        inputs: Optional pre-built inputs. When ``None`` they are built
            with ``self.forward_preprocess(data)``; otherwise the given
            value is passed through ``self.transfer_data_to_device`` with
            the pipeline's ``device`` and ``torch_dtype``.
        return_inputs: When true, return the prepared inputs instead of a
            loss. Not consulted for the "dpo" task, which returns before
            any input preparation happens.

    Returns:
        The value of ``DPOLoss().loss(self, data)`` for "dpo"; the prepared
        inputs when ``return_inputs`` is true or the task is
        "data_process"; ``self.pipe.training_loss(...)`` for "sft";
        ``self.pipe.direct_distill_loss(...)`` for "direct_distill".

    Raises:
        NotImplementedError: If ``self.task`` is none of the tasks above.
    """
    # DPO (DPO requires a special training loss)
    # It works from the raw sample, so it bypasses the shared input path.
    if self.task == "dpo":
        loss = DPOLoss().loss(self, data)
        return loss

    # Inputs
    if inputs is None:
        inputs = self.forward_preprocess(data)
    else:
        inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
    if return_inputs:
        return inputs

    # Loss
    if self.task == "sft":
        # Collect the pipeline attributes named in `in_iteration_models`
        # and forward them to the loss as keyword arguments.
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
    elif self.task == "data_process":
        # No loss here: the prepared inputs themselves are the result.
        loss = inputs
    elif self.task == "direct_distill":
        loss = self.pipe.direct_distill_loss(**inputs)
    else:
        raise NotImplementedError(f"Unsupported task: {self.task}.")
    return loss
|
||||
|
||||
|
||||
|
||||
@@ -143,5 +148,6 @@ if __name__ == "__main__":
|
||||
"sft": launch_training_task,
|
||||
"data_process": launch_data_process_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"dpo": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||
|
||||
Reference in New Issue
Block a user