z image distill

This commit is contained in:
Artiprocher
2025-12-03 11:20:49 +08:00
parent 5065c9ef6a
commit 4a80e9c179
8 changed files with 185 additions and 26 deletions

View File

@@ -1,4 +1,4 @@
import torch, os, argparse, accelerate
import torch, os, argparse, accelerate, copy
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
from diffsynth.diffusion import *
@@ -50,6 +50,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
if task == "trajectory_imitation":
# This is an experimental feature.
# We may remove it in the future.
self.loss_fn = TrajectoryImitationLoss()
self.task_to_loss["trajectory_imitation"] = self.loss_fn
self.pipe_teacher = copy.deepcopy(self.pipe)
self.pipe_teacher.requires_grad_(False)
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
@@ -67,6 +74,9 @@ class ZImageTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
if self.task == "trajectory_imitation":
inputs_shared["cfg_scale"] = 2
inputs_shared["teacher"] = self.pipe_teacher
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
@@ -138,5 +148,6 @@ if __name__ == "__main__":
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
"trajectory_imitation": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)