Merge pull request #749 from mi804/training_args

support num_workers, save_steps, find_unused_parameters
This commit is contained in:
Zhongjie Duan
2025-08-06 15:54:04 +08:00
committed by GitHub
11 changed files with 78 additions and 14 deletions

View File

@@ -30,7 +30,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
else:
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -50,7 +50,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
def forward_preprocess(self, data):
# CFG-sensitive parameters
@@ -117,4 +117,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
)