mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
refactor model_logger
This commit is contained in:
@@ -365,26 +365,42 @@ class ModelLogger:
|
|||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
|
||||||
self.state_dict_converter = state_dict_converter
|
self.state_dict_converter = state_dict_converter
|
||||||
|
self.num_steps = 0
|
||||||
|
|
||||||
|
|
||||||
def on_step_end(self, loss):
|
def on_step_end(self, accelerator, model, save_steps=None):
|
||||||
pass
|
self.num_steps += 1
|
||||||
|
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||||
|
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
def on_model_save(self, accelerator, model, step_id=None, epoch_id=None):
|
def on_epoch_end(self, accelerator, model, epoch_id):
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
state_dict = accelerator.get_state_dict(model)
|
state_dict = accelerator.get_state_dict(model)
|
||||||
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
|
||||||
state_dict = self.state_dict_converter(state_dict)
|
state_dict = self.state_dict_converter(state_dict)
|
||||||
os.makedirs(self.output_path, exist_ok=True)
|
os.makedirs(self.output_path, exist_ok=True)
|
||||||
if step_id is not None:
|
|
||||||
path = os.path.join(self.output_path, f"step-{step_id}.safetensors")
|
|
||||||
else:
|
|
||||||
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
|
||||||
accelerator.save(state_dict, path, safe_serialization=True)
|
accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
|
def on_training_end(self, accelerator, model, save_steps=None):
    """Save one last step checkpoint when training stops between save points.

    Nothing is written when `save_steps` is None, or when `self.num_steps`
    is an exact multiple of `save_steps`. Otherwise the model is saved as
    `step-<num_steps>.safetensors` through `self.save_model`.

    Args:
        accelerator: Forwarded unchanged to `self.save_model`.
        model: Forwarded unchanged to `self.save_model`.
        save_steps: Step-checkpoint interval, or None when step
            checkpoints are disabled.
    """
    if save_steps is None:
        return
    steps_past_last_interval = self.num_steps % save_steps
    if steps_past_last_interval == 0:
        return
    self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(self, accelerator, model, file_name):
    """Write the model's trainable weights to `<self.output_path>/<file_name>`.

    All processes synchronize and take part in collecting the state dict;
    only the main process post-processes it and writes the file.

    Args:
        accelerator: Object providing `wait_for_everyone`, `get_state_dict`,
            `unwrap_model`, `save` and the `is_main_process` flag.
        model: The prepared model. Its unwrapped form must provide
            `export_trainable_state_dict(state_dict, remove_prefix=...)`.
        file_name: File name of the checkpoint inside `self.output_path`.
    """
    accelerator.wait_for_everyone()
    # Collect the state dict on EVERY process, not just the main one: with
    # sharded backends (e.g. DeepSpeed ZeRO-3, FSDP) the gather is a
    # collective operation, and calling it on a single rank can deadlock or
    # produce incomplete weights.
    state_dict = accelerator.get_state_dict(model)
    if not accelerator.is_main_process:
        return
    unwrapped_model = accelerator.unwrap_model(model)
    state_dict = unwrapped_model.export_trainable_state_dict(
        state_dict, remove_prefix=self.remove_prefix_in_ckpt
    )
    state_dict = self.state_dict_converter(state_dict)
    os.makedirs(self.output_path, exist_ok=True)
    path = os.path.join(self.output_path, file_name)
    accelerator.save(state_dict, path, safe_serialization=True)
|
||||||
|
|
||||||
|
|
||||||
def launch_training_task(
|
def launch_training_task(
|
||||||
dataset: torch.utils.data.Dataset,
|
dataset: torch.utils.data.Dataset,
|
||||||
model: DiffusionTrainingModule,
|
model: DiffusionTrainingModule,
|
||||||
@@ -404,26 +420,18 @@ def launch_training_task(
|
|||||||
)
|
)
|
||||||
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
|
||||||
|
|
||||||
global_steps = 0
|
|
||||||
for epoch_id in range(num_epochs):
|
for epoch_id in range(num_epochs):
|
||||||
for step_id, data in enumerate(tqdm(dataloader)):
|
for data in tqdm(dataloader):
|
||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss = model(data)
|
loss = model(data)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(loss)
|
model_logger.on_step_end(accelerator, model, save_steps)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
global_steps = epoch_id * len(dataloader) + step_id + 1
|
|
||||||
# save every `save_steps` steps
|
|
||||||
if save_steps is not None and global_steps % save_steps == 0:
|
|
||||||
model_logger.on_model_save(accelerator, model, step_id=global_steps)
|
|
||||||
# save the model at the end of each epoch if save_steps is None
|
|
||||||
if save_steps is None:
|
if save_steps is None:
|
||||||
model_logger.on_model_save(accelerator, model, epoch_id=epoch_id)
|
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||||
# save the final model if save_steps is not None
|
model_logger.on_training_end(accelerator, model, save_steps)
|
||||||
if save_steps is not None:
|
|
||||||
model_logger.on_model_save(accelerator, model, step_id=global_steps)
|
|
||||||
|
|
||||||
|
|
||||||
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
|
||||||
|
|||||||
Reference in New Issue
Block a user