refactor model_logger

This commit is contained in:
mi804
2025-08-06 15:47:35 +08:00
parent 3915bc3ee6
commit ef09db69cd

View File

@@ -365,23 +365,39 @@ class ModelLogger:
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
def on_step_end(self, loss):
pass
self.num_steps = 0
def on_model_save(self, accelerator, model, step_id=None, epoch_id=None):
def on_step_end(self, accelerator, model, save_steps=None):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
if step_id is not None:
path = os.path.join(self.output_path, f"step-{step_id}.safetensors")
else:
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator, model, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator, model, file_name):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)
@@ -404,26 +420,18 @@ def launch_training_task(
)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
global_steps = 0
for epoch_id in range(num_epochs):
for step_id, data in enumerate(tqdm(dataloader)):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(loss)
model_logger.on_step_end(accelerator, model, save_steps)
scheduler.step()
global_steps = epoch_id * len(dataloader) + step_id + 1
# save every `save_steps` steps
if save_steps is not None and global_steps % save_steps == 0:
model_logger.on_model_save(accelerator, model, step_id=global_steps)
# save the model at the end of each epoch if save_steps is None
if save_steps is None:
model_logger.on_model_save(accelerator, model, epoch_id=epoch_id)
# save the final model if save_steps is not None
if save_steps is not None:
model_logger.on_model_save(accelerator, model, step_id=global_steps)
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):