Reorder optimizer and logger calls in training loop (#1404)

This commit is contained in:
Qifan Zhang
2026-04-21 13:45:09 +08:00
committed by GitHub
parent 079e51c9f3
commit 5c89a15b9a

View File

@@ -33,15 +33,15 @@ def launch_training_task(
     for epoch_id in range(num_epochs):
         for data in tqdm(dataloader):
             with accelerator.accumulate(model):
-                optimizer.zero_grad()
                 if dataset.load_from_cache:
                     loss = model({}, inputs=data)
                 else:
                     loss = model(data)
                 accelerator.backward(loss)
                 optimizer.step()
-                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                 scheduler.step()
+                optimizer.zero_grad()
+                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
     model_logger.on_training_end(accelerator, model, save_steps)