update swanlab log

This commit is contained in:
Artiprocher
2025-02-27 19:21:51 +08:00
parent 1419bec53d
commit b5c1d33e58
2 changed files with 14 additions and 46 deletions

View File

@@ -256,29 +256,11 @@ def add_general_parsers(parser):
action="store_true", action="store_true",
help="Whether to use SwanLab logger.", help="Whether to use SwanLab logger.",
) )
parser.add_argument(
"--swanlab_project",
type=str,
default="diffsynth_studio",
help="SwanLab project name.",
)
parser.add_argument(
"--swanlab_name",
type=str,
default="diffsynth_studio_train",
help="SwanLab experiment name.",
)
parser.add_argument( parser.add_argument(
"--swanlab_mode", "--swanlab_mode",
default=None, default=None,
help="SwanLab mode (cloud or local).", help="SwanLab mode (cloud or local).",
) )
parser.add_argument(
"--swanlab_logdir",
type=str,
default=None,
help="SwanLab local log directory.",
)
return parser return parser
@@ -301,18 +283,20 @@ def launch_training_task(model, args):
) )
# set swanlab logger # set swanlab logger
swanlab_logger = None
if args.use_swanlab: if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args)) swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger( swanlab_logger = SwanLabLogger(
project=args.swanlab_project, project="diffsynth_studio",
name=args.swanlab_name, name="diffsynth_studio",
config=swanlab_config, config=swanlab_config,
mode=args.swanlab_mode, mode=args.swanlab_mode,
logdir=args.swanlab_logdir, logdir=args.output_path,
) )
logger = [swanlab_logger]
else:
logger = []
# train # train
trainer = pl.Trainer( trainer = pl.Trainer(
@@ -324,7 +308,7 @@ def launch_training_task(model, args):
default_root_dir=args.output_path, default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches, accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=[swanlab_logger], logger=logger,
) )
trainer.fit(model=model, train_dataloaders=train_loader) trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -400,29 +400,11 @@ def parse_args():
action="store_true", action="store_true",
help="Whether to use SwanLab logger.", help="Whether to use SwanLab logger.",
) )
parser.add_argument(
"--swanlab_project",
type=str,
default="wan_t2v",
help="SwanLab project name.",
)
parser.add_argument(
"--swanlab_name",
type=str,
default="wan_t2v_train",
help="SwanLab experiment name.",
)
parser.add_argument( parser.add_argument(
"--swanlab_mode", "--swanlab_mode",
default=None, default=None,
help="SwanLab mode (cloud or local).", help="SwanLab mode (cloud or local).",
) )
parser.add_argument(
"--swanlab_logdir",
type=str,
default=None,
help="SwanLab local log directory.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -493,18 +475,20 @@ def train(args):
init_lora_weights=args.init_lora_weights, init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing use_gradient_checkpointing=args.use_gradient_checkpointing
) )
swanlab_logger = None
if args.use_swanlab: if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args)) swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger( swanlab_logger = SwanLabLogger(
project=args.swanlab_project, project="wan",
name=args.swanlab_name, name="wan",
config=swanlab_config, config=swanlab_config,
mode=args.swanlab_mode, mode=args.swanlab_mode,
logdir=args.swanlab_logdir, logdir=args.output_path,
) )
logger = [swanlab_logger]
else:
logger = []
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
accelerator="gpu", accelerator="gpu",
@@ -513,7 +497,7 @@ def train(args):
default_root_dir=args.output_path, default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches, accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=[swanlab_logger], logger=logger,
) )
trainer.fit(model, dataloader) trainer.fit(model, dataloader)