From ef09db69cd471b614a3e6e6cba247a80c9cb3c39 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 6 Aug 2025 15:47:35 +0800 Subject: [PATCH] refactor model_logger --- diffsynth/trainers/utils.py | 52 +++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index fff84d5..65e4e50 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -365,23 +365,39 @@ class ModelLogger: self.output_path = output_path self.remove_prefix_in_ckpt = remove_prefix_in_ckpt self.state_dict_converter = state_dict_converter - - - def on_step_end(self, loss): - pass - + self.num_steps = 0 - def on_model_save(self, accelerator, model, step_id=None, epoch_id=None): + + def on_step_end(self, accelerator, model, save_steps=None): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def on_epoch_end(self, accelerator, model, epoch_id): accelerator.wait_for_everyone() if accelerator.is_main_process: state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) - if step_id is not None: - path = os.path.join(self.output_path, f"step-{step_id}.safetensors") - else: - path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + def on_training_end(self, accelerator, model, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator, model, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) accelerator.save(state_dict, path, safe_serialization=True) @@ -404,26 +420,18 @@ def launch_training_task( ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) - global_steps = 0 for epoch_id in range(num_epochs): - for step_id, data in enumerate(tqdm(dataloader)): + for data in tqdm(dataloader): with accelerator.accumulate(model): optimizer.zero_grad() loss = model(data) accelerator.backward(loss) optimizer.step() - model_logger.on_step_end(loss) + model_logger.on_step_end(accelerator, model, save_steps) scheduler.step() - global_steps = epoch_id * len(dataloader) + step_id + 1 - # save every `save_steps` steps - if save_steps is not None and global_steps % save_steps == 0: - model_logger.on_model_save(accelerator, model, step_id=global_steps) - # save the model at the end of each epoch if save_steps is None if save_steps is None: - model_logger.on_model_save(accelerator, model, epoch_id=epoch_id) - # save the final model if save_steps is not None - if save_steps is not None: - model_logger.on_model_save(accelerator, model, step_id=global_steps) + model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):