This commit is contained in:
mi804
2025-09-22 17:45:42 +08:00
parent bf7b339efb
commit d96709fb6a
4 changed files with 88 additions and 191 deletions

View File

@@ -1,5 +1,4 @@
import torch, os, json
from diffsynth import load_state_dict
import torch, os
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
@@ -7,7 +6,6 @@ from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class QwenImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
@@ -20,6 +18,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
extra_inputs=None,
enable_fp8_training=False,
task="sft",
beta_dpo=1000.,
):
super().__init__()
# Load models
@@ -40,8 +39,9 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.task = task
self.lora_base_model = lora_base_model
self.beta_dpo = beta_dpo
def forward_preprocess(self, data):
# CFG-sensitive parameters
inputs_posi = {"prompt": data["prompt"]}
@@ -81,9 +81,44 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None, return_inputs=False, **kwargs):
def forward_dpo(self, data, accelerator=None):
# Loss DPO: -logσ(−β(diff_policy diff_ref))
# Prepare inputs
win_data = {key: data[key] for key in ["prompt", "image"]}
lose_data = {"prompt": None, "image": data["lose_image"]}
inputs_win = self.forward_preprocess(win_data)
inputs_lose = self.forward_preprocess(lose_data)
inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]})
inputs_win.pop('noise')
inputs_lose.pop('noise')
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
# sample timestep and noise
timestep = self.pipe.sample_timestep()
noise = torch.rand_like(inputs_win["latents"])
# compute diff_policy = loss_win - loss_lose
loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
diff_policy = loss_win - loss_lose
# compute diff_ref
if self.lora_base_model is not None:
self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
# load the original model weights
with torch.no_grad():
loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
diff_ref = loss_win_ref - loss_lose_ref
self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
else:
# TODO: may support full model training
raise NotImplementedError("DPO with full model training is not supported yet.")
# compute loss
loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean()
return loss
def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs):
if self.task == "dpo":
return self.forward_dpo(data, accelerator=accelerator)
# Inputs
if inputs is None:
inputs = self.forward_preprocess(data)
@@ -137,11 +172,13 @@ if __name__ == "__main__":
extra_inputs=args.extra_inputs,
enable_fp8_training=args.enable_fp8_training,
task=args.task,
beta_dpo=args.beta_dpo,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
launcher_map = {
"sft": launch_training_task,
"data_process": launch_data_process_task,
"direct_distill": launch_training_task,
"dpo": launch_training_task,
}
launcher_map[args.task](dataset, model, model_logger, args=args)