feat: add swanlab logger

This commit is contained in:
ZeYi Lin
2025-02-26 17:12:54 +08:00
parent cf12723c89
commit 1419bec53d
2 changed files with 102 additions and 2 deletions

View File

@@ -250,6 +250,36 @@ def add_general_parsers(parser):
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_project",
type=str,
default="diffsynth_studio",
help="SwanLab project name.",
)
parser.add_argument(
"--swanlab_name",
type=str,
default="diffsynth_studio_train",
help="SwanLab experiment name.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
parser.add_argument(
"--swanlab_logdir",
type=str,
default=None,
help="SwanLab local log directory.",
)
return parser
@@ -270,6 +300,20 @@ def launch_training_task(model, args):
num_workers=args.dataloader_num_workers
)
# set swanlab logger
swanlab_logger = None
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project=args.swanlab_project,
name=args.swanlab_name,
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.swanlab_logdir,
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
@@ -279,7 +323,8 @@ def launch_training_task(model, args):
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=[swanlab_logger],
)
trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -394,6 +394,35 @@ def parse_args():
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_project",
type=str,
default="wan_t2v",
help="SwanLab project name.",
)
parser.add_argument(
"--swanlab_name",
type=str,
default="wan_t2v_train",
help="SwanLab experiment name.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
parser.add_argument(
"--swanlab_logdir",
type=str,
default=None,
help="SwanLab local log directory.",
)
args = parser.parse_args()
return args
@@ -421,10 +450,23 @@ def data_process(args):
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width),
)
swanlab_logger = None
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project=args.swanlab_project,
name=args.swanlab_name,
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.swanlab_logdir,
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
logger=[swanlab_logger],
)
trainer.test(model, dataloader)
@@ -451,6 +493,18 @@ def train(args):
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing
)
swanlab_logger = None
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project=args.swanlab_project,
name=args.swanlab_name,
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.swanlab_logdir,
)
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
@@ -458,7 +512,8 @@ def train(args):
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=[swanlab_logger],
)
trainer.fit(model, dataloader)