Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Reorder optimizer and logger calls in training loop (#1404)
@@ -33,15 +33,15 @@ def launch_training_task(
     for epoch_id in range(num_epochs):
         for data in tqdm(dataloader):
             with accelerator.accumulate(model):
-                optimizer.zero_grad()
                 if dataset.load_from_cache:
                     loss = model({}, inputs=data)
                 else:
                     loss = model(data)
                 accelerator.backward(loss)
                 optimizer.step()
-                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                 scheduler.step()
+                optimizer.zero_grad()
+                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
         if save_steps is None:
             model_logger.on_epoch_end(accelerator, model, epoch_id)
     model_logger.on_training_end(accelerator, model, save_steps)
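
For context, here is a minimal runnable sketch of the per-step ordering this commit produces, written against Hugging Face Accelerate's gradient-accumulation context. The toy model, dataset, optimizer, scheduler, and the StepLogger class are illustrative stand-ins, not DiffSynth-Studio's actual launch_training_task signature or ModelLogger API; only the call order inside accelerator.accumulate(...) mirrors the diff above.

    # Illustrative sketch of the reordered training step, not the DiffSynth-Studio code.
    # Requires: pip install torch accelerate tqdm
    import torch
    from accelerate import Accelerator
    from tqdm import tqdm


    class StepLogger:
        # Hypothetical stand-in for DiffSynth-Studio's model_logger.
        def on_step_end(self, accelerator, model, save_steps, loss=None):
            if accelerator.is_main_process:
                print(f"step loss: {loss.item():.4f}")


    def train(num_epochs=1, gradient_accumulation_steps=2, save_steps=None):
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
        dataset = [(torch.randn(4), torch.randn(1)) for _ in range(16)]
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
        model_logger = StepLogger()
        model, optimizer, dataloader, scheduler = accelerator.prepare(
            model, optimizer, dataloader, scheduler
        )

        for epoch_id in range(num_epochs):
            for x, y in tqdm(dataloader):
                with accelerator.accumulate(model):
                    loss = torch.nn.functional.mse_loss(model(x), y)  # forward
                    accelerator.backward(loss)                        # backward
                    optimizer.step()                                  # update weights
                    scheduler.step()                                  # advance LR schedule
                    optimizer.zero_grad()                             # clear grads after the update
                    model_logger.on_step_end(accelerator, model, save_steps, loss=loss)


    if __name__ == "__main__":
        train()

Relative to the pre-change order, optimizer.zero_grad() now clears gradients after the update instead of at the top of the step, and model_logger.on_step_end(...) fires only once both the optimizer and scheduler have advanced, so any metric or checkpoint taken in that hook reflects the completed step.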