mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
z image distill
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import torch, os, argparse, accelerate
|
||||
import torch, os, argparse, accelerate, copy
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
@@ -50,6 +50,13 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
if task == "trajectory_imitation":
|
||||
# This is an experimental feature.
|
||||
# We may remove it in the future.
|
||||
self.loss_fn = TrajectoryImitationLoss()
|
||||
self.task_to_loss["trajectory_imitation"] = self.loss_fn
|
||||
self.pipe_teacher = copy.deepcopy(self.pipe)
|
||||
self.pipe_teacher.requires_grad_(False)
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
@@ -67,6 +74,9 @@ class ZImageTrainingModule(DiffusionTrainingModule):
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
if self.task == "trajectory_imitation":
|
||||
inputs_shared["cfg_scale"] = 2
|
||||
inputs_shared["teacher"] = self.pipe_teacher
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
@@ -138,5 +148,6 @@ if __name__ == "__main__":
|
||||
"sft:train": launch_training_task,
|
||||
"direct_distill": launch_training_task,
|
||||
"direct_distill:train": launch_training_task,
|
||||
"trajectory_imitation": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
|
||||
Reference in New Issue
Block a user