mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-23 14:26:16 +00:00
Reorder optimizer and logger calls in training loop (#1404)
This commit is contained in:
@@ -33,15 +33,15 @@ def launch_training_task(
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data)
|
||||
else:
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
Reference in New Issue
Block a user