mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
support dpo training
This commit is contained in:
@@ -525,7 +525,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
return split_result
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
if pipe.text_encoder is not None:
|
||||
if pipe.text_encoder is not None and prompt is not None:
|
||||
prompt = [prompt]
|
||||
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
||||
if edit_image is None:
|
||||
|
||||
@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
param.data = param.to(upcast_dtype)
|
||||
return model
|
||||
|
||||
def disable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(False)
|
||||
|
||||
def enable_all_lora_layers(self, model):
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'enable_adapters'):
|
||||
module.enable_adapters(True)
|
||||
|
||||
def mapping_lora_state_dict(self, state_dict):
|
||||
new_state_dict = {}
|
||||
@@ -554,9 +563,9 @@ def launch_training_task(
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data)
|
||||
loss = model({}, inputs=data, accelerator=accelerator)
|
||||
else:
|
||||
loss = model(data)
|
||||
loss = model(data, accelerator=accelerator)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
@@ -690,4 +699,5 @@ def qwen_image_parser():
|
||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
||||
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
|
||||
return parser
|
||||
|
||||
@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
|
||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||
return latents_next
|
||||
|
||||
def sample_timestep(self):
|
||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||
return timestep
|
||||
|
||||
def training_loss_minimum(self, noise, timestep, **inputs):
|
||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user