mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:23:43 +00:00
44 lines
2.1 KiB
Python
44 lines
2.1 KiB
Python
import os, torch
|
|
from accelerate import Accelerator
|
|
|
|
|
|
class ModelLogger:
    """Count training steps and write trainable-weight checkpoints.

    Checkpoints are written with ``accelerator.save(..., safe_serialization=True)``
    into ``output_path`` when the caller invokes the hooks below:
    ``on_step_end`` (every ``save_steps`` steps), ``on_epoch_end`` (every
    epoch) and ``on_training_end`` (final step, if not already saved).
    """

    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
        """
        Args:
            output_path: Directory checkpoints are written to; created on the
                first save if it does not exist.
            remove_prefix_in_ckpt: Forwarded as ``remove_prefix`` to the
                model's ``export_trainable_state_dict``.
            state_dict_converter: Callable applied to the exported state dict
                right before saving. Identity by default.
        """
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter
        # Number of times on_step_end has been called so far.
        self.num_steps = 0

    @staticmethod
    def _step_saving_enabled(save_steps):
        """Return True when periodic step checkpoints were requested."""
        # None (the default) and non-positive values disable step checkpoints;
        # a value of 0 would otherwise raise ZeroDivisionError in the modulo.
        return save_steps is not None and save_steps > 0

    def on_step_end(self, accelerator: "Accelerator", model: torch.nn.Module, save_steps=None):
        """Advance the step counter and checkpoint every ``save_steps`` steps.

        The checkpoint is named ``step-<num_steps>.safetensors``.
        """
        self.num_steps += 1
        if self._step_saving_enabled(save_steps) and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator: "Accelerator", model: torch.nn.Module, epoch_id):
        """Write ``epoch-<epoch_id>.safetensors`` at the end of an epoch."""
        # Delegate so epoch and step checkpoints share one (collective-safe) path.
        self.save_model(accelerator, model, f"epoch-{epoch_id}.safetensors")

    def on_training_end(self, accelerator: "Accelerator", model: torch.nn.Module, save_steps=None):
        """Save the final step unless on_step_end already saved it.

        Nothing is written if step checkpointing is disabled, or if the last
        step count is a multiple of ``save_steps`` (including zero steps).
        """
        if self._step_saving_enabled(save_steps) and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator: "Accelerator", model: torch.nn.Module, file_name):
        """Export the model's trainable weights to ``output_path/file_name``.

        Must be called on every process. Only the main process writes the file.

        Args:
            accelerator: The Accelerate ``Accelerator`` driving training.
            model: The (possibly wrapped) model; its unwrapped form must
                provide ``export_trainable_state_dict``.
            file_name: File name, relative to ``output_path``.
        """
        accelerator.wait_for_everyone()
        # get_state_dict runs on every rank: under sharded strategies such as
        # DeepSpeed ZeRO-3 it gathers weights collectively, so calling it only
        # on the main process would leave that process waiting forever.
        state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(
                state_dict, remove_prefix=self.remove_prefix_in_ckpt
            )
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)