mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
refactor model_logger
This commit is contained in:
@@ -365,23 +365,39 @@ class ModelLogger:
|
||||
self.output_path = output_path
|
||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||
self.state_dict_converter = state_dict_converter
|
||||
|
||||
|
||||
def on_step_end(self, loss):
|
||||
pass
|
||||
|
||||
self.num_steps = 0
|
||||
|
||||
def on_model_save(self, accelerator, model, step_id=None, epoch_id=None):
|
||||
|
||||
def on_step_end(self, accelerator, model, save_steps=None):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
if step_id is not None:
|
||||
path = os.path.join(self.output_path, f"step-{step_id}.safetensors")
|
||||
else:
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
def on_training_end(self, accelerator, model, save_steps=None):
|
||||
if save_steps is not None and self.num_steps % save_steps != 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
|
||||
def save_model(self, accelerator, model, file_name):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
state_dict = accelerator.get_state_dict(model)
|
||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||
state_dict = self.state_dict_converter(state_dict)
|
||||
os.makedirs(self.output_path, exist_ok=True)
|
||||
path = os.path.join(self.output_path, file_name)
|
||||
accelerator.save(state_dict, path, safe_serialization=True)
|
||||
|
||||
|
||||
@@ -404,26 +420,18 @@ def launch_training_task(
|
||||
)
|
||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||
|
||||
global_steps = 0
|
||||
for epoch_id in range(num_epochs):
|
||||
for step_id, data in enumerate(tqdm(dataloader)):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(loss)
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
scheduler.step()
|
||||
global_steps = epoch_id * len(dataloader) + step_id + 1
|
||||
# save every `save_steps` steps
|
||||
if save_steps is not None and global_steps % save_steps == 0:
|
||||
model_logger.on_model_save(accelerator, model, step_id=global_steps)
|
||||
# save the model at the end of each epoch if save_steps is None
|
||||
if save_steps is None:
|
||||
model_logger.on_model_save(accelerator, model, epoch_id=epoch_id)
|
||||
# save the final model if save_steps is not None
|
||||
if save_steps is not None:
|
||||
model_logger.on_model_save(accelerator, model, step_id=global_steps)
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
|
||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||
|
||||
Reference in New Issue
Block a user