mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
1 Commits
cache_lear
...
dpo-refine
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6737dbfc9f |
@@ -475,6 +475,64 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
if len(load_result[1]) > 0:
|
if len(load_result[1]) > 0:
|
||||||
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
|
||||||
setattr(pipe, lora_base_model, model)
|
setattr(pipe, lora_base_model, model)
|
||||||
|
|
||||||
|
def disable_all_lora_layers(self, model):
    """Switch off every LoRA adapter layer contained in ``model``.

    Walks all submodules and, for each one exposing an ``enable_adapters``
    hook (presumably a PEFT-style LoRA-wrapped layer — confirm against the
    injection code), disables its adapters so the module temporarily
    behaves like the base model.
    """
    for _, submodule in model.named_modules():
        if hasattr(submodule, "enable_adapters"):
            submodule.enable_adapters(False)
|
def enable_all_lora_layers(self, model):
    """Switch every LoRA adapter layer contained in ``model`` back on.

    Mirror of ``disable_all_lora_layers``: walks all submodules and
    re-enables adapters on each one that exposes an ``enable_adapters``
    hook (presumably a PEFT-style LoRA-wrapped layer).
    """
    for _, submodule in model.named_modules():
        if hasattr(submodule, "enable_adapters"):
            submodule.enable_adapters(True)
||||||
|
class DPOLoss:
    """Diffusion-DPO training loss for LoRA fine-tuning.

    Computes  loss = -log sigma(-beta * (diff_policy - diff_ref))  where
    ``diff_* = mse(win) - mse(lose)`` is the denoising-loss gap between the
    preferred ("win") and rejected ("lose") image for the policy (LoRA
    enabled) and the reference model (LoRA disabled), respectively.
    """

    def __init__(self, beta=2500):
        # Temperature of the DPO objective; scales the policy/reference gap
        # inside the log-sigmoid.
        self.beta = beta

    def sample_timestep(self, pipe):
        """Sample one training timestep uniformly from the scheduler's table.

        Returns a 1-element tensor on the pipeline's device/dtype.
        """
        timestep_id = torch.randint(0, pipe.scheduler.num_train_timesteps, (1,))
        timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
        return timestep

    def training_loss_minimum(self, pipe, noise, timestep, **inputs):
        """Denoising MSE for one (noise, timestep) draw.

        ``inputs`` must contain ``input_latents`` (clean latents) plus any
        extra kwargs ``pipe.model_fn`` expects; ``latents`` is overwritten
        with the noised version before the model call.
        """
        inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
        training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
        noise_pred = pipe.model_fn(**inputs, timestep=timestep)
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        # Timestep-dependent reweighting of the MSE, as defined by the scheduler.
        loss = loss * pipe.scheduler.training_weight(timestep)
        return loss

    def loss(self, model, data):
        """DPO loss for one preference pair.

        ``data`` must provide ``prompt``, ``image`` (preferred) and
        ``lose_image`` (rejected); ``model`` is the training module that
        owns the pipeline and the LoRA toggle helpers.
        """
        # Loss DPO: -logσ(−β(diff_policy − diff_ref))
        # Prepare inputs
        win_data = {key: data[key] for key in ["prompt", "image"]}
        lose_data = {"prompt": data["prompt"], "image": data["lose_image"]}
        inputs_win = model.forward_preprocess(win_data)
        inputs_lose = model.forward_preprocess(lose_data)
        # Drop the preprocessing noise: win/lose must share the SAME noise
        # draw (sampled below) for the comparison to be paired.
        inputs_win.pop('noise')
        inputs_lose.pop('noise')
        models = {name: getattr(model.pipe, name) for name in model.pipe.in_iteration_models}
        # sample timestep and noise (shared across all four loss evaluations)
        timestep = self.sample_timestep(model.pipe)
        # FIX: diffusion training noise must be standard Gaussian.
        # torch.rand_like draws uniform [0, 1) samples, which breaks the
        # forward-process statistics assumed by scheduler.add_noise /
        # training_target; use randn_like instead.
        noise = torch.randn_like(inputs_win["latents"])
        # compute diff_policy = loss_win - loss_lose  (LoRA active)
        loss_win = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
        loss_lose = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
        diff_policy = loss_win - loss_lose
        # compute diff_ref: disable LoRA so the base weights act as the
        # frozen reference policy.
        # TODO: may support full model training
        model.disable_all_lora_layers(model.pipe.dit)
        # load the original model weights
        with torch.no_grad():
            loss_win_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_win)
            loss_lose_ref = self.training_loss_minimum(model.pipe, noise, timestep, **models, **inputs_lose)
            diff_ref = loss_win_ref - loss_lose_ref
        model.enable_all_lora_layers(model.pipe.dit)
        # compute loss
        loss = -1. * torch.nn.functional.logsigmoid(self.beta * (diff_ref - diff_policy)).mean()
        return loss
|
|
||||||
|
|
||||||
class ModelLogger:
|
class ModelLogger:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import torch, os, json
|
|||||||
from diffsynth import load_state_dict
|
from diffsynth import load_state_dict
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task, DPOLoss
|
||||||
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
from diffsynth.trainers.unified_dataset import UnifiedDataset
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
@@ -84,24 +84,29 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, data, inputs=None, return_inputs=False):
    """Compute the training loss for one batch, dispatched on ``self.task``.

    data: raw batch (prompt/image fields) to be preprocessed.
    inputs: optional pre-computed inputs; when given they are moved to the
        pipeline's device/dtype instead of re-running preprocessing.
    return_inputs: when True, return the preprocessed inputs instead of a
        loss (used by the data-processing pipeline).
    """
    # DPO (DPO requires a special training loss)
    # DPO builds its own win/lose inputs internally, so it bypasses the
    # shared preprocessing below.
    if self.task == "dpo":
        loss = DPOLoss().loss(self, data)
        return loss
    else:
        # Inputs
        if inputs is None:
            inputs = self.forward_preprocess(data)
        else:
            inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
        if return_inputs: return inputs

        # Loss
        if self.task == "sft":
            # Collect the models the pipeline marks as trainable this iteration.
            models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
            loss = self.pipe.training_loss(**models, **inputs)
        elif self.task == "data_process":
            # No loss: the "loss" slot carries the preprocessed inputs downstream.
            loss = inputs
        elif self.task == "direct_distill":
            loss = self.pipe.direct_distill_loss(**inputs)
        else:
            raise NotImplementedError(f"Unsupported task: {self.task}.")
        return loss
|
|
||||||
|
|
||||||
|
|
||||||
@@ -143,5 +148,6 @@ if __name__ == "__main__":
|
|||||||
"sft": launch_training_task,
|
"sft": launch_training_task,
|
||||||
"data_process": launch_data_process_task,
|
"data_process": launch_data_process_task,
|
||||||
"direct_distill": launch_training_task,
|
"direct_distill": launch_training_task,
|
||||||
|
"dpo": launch_training_task,
|
||||||
}
|
}
|
||||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||||
|
|||||||
Reference in New Issue
Block a user