mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
support direct distill
This commit is contained in:
@@ -19,6 +19,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
enable_fp8_training=False,
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
@@ -38,6 +39,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.task = task
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
@@ -82,11 +84,21 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
|
||||
|
||||
def forward(self, data, inputs=None, return_inputs=False):
|
||||
# Inputs
|
||||
if inputs is None: inputs = self.forward_preprocess(data)
|
||||
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
|
||||
if return_inputs: return inputs
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
|
||||
# Loss
|
||||
if self.task == "sft":
|
||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||
loss = self.pipe.training_loss(**models, **inputs)
|
||||
elif self.task == "data_process":
|
||||
loss = inputs
|
||||
elif self.task == "direct_distill":
|
||||
loss = self.pipe.direct_distill_loss(**inputs)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported task: {self.task}.")
|
||||
return loss
|
||||
|
||||
|
||||
@@ -122,10 +134,12 @@ if __name__ == "__main__":
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
enable_fp8_training=args.enable_fp8_training,
|
||||
task=args.task,
|
||||
)
|
||||
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||
launcher_map = {
|
||||
"sft": launch_training_task,
|
||||
"data_process": launch_data_process_task
|
||||
"data_process": launch_data_process_task,
|
||||
"direct_distill": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||
|
||||
Reference in New Issue
Block a user