diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index e2cab59..d811a40 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -250,6 +250,36 @@ def add_general_parsers(parser): default=None, help="Pretrained LoRA path. Required if the training is resumed.", ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_project", + type=str, + default="diffsynth_studio", + help="SwanLab project name.", + ) + parser.add_argument( + "--swanlab_name", + type=str, + default="diffsynth_studio_train", + help="SwanLab experimentname.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--swanlab_logdir", + type=str, + default=None, + help="SwanLab local log directory.", + ) + return parser @@ -270,6 +300,20 @@ def launch_training_task(model, args): num_workers=args.dataloader_num_workers ) + # set swanlab logger + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) + # train trainer = pl.Trainer( max_epochs=args.max_epochs, @@ -279,7 +323,8 @@ def launch_training_task(model, args): strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=[swanlab_logger], ) trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 53de964..c26db7a 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -394,6 +394,35 @@ def parse_args(): choices=["lora", "full"], help="Model structure to train. LoRA training or full training.", ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_project", + type=str, + default="wan_t2v", + help="SwanLab project name.", + ) + parser.add_argument( + "--swanlab_name", + type=str, + default="wan_t2v_train", + help="SwanLab experimentname.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--swanlab_logdir", + type=str, + default=None, + help="SwanLab local log directory.", + ) args = parser.parse_args() return args @@ -421,10 +450,23 @@ def data_process(args): tile_size=(args.tile_size_height, args.tile_size_width), tile_stride=(args.tile_stride_height, args.tile_stride_width), ) + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) trainer = pl.Trainer( accelerator="gpu", devices="auto", default_root_dir=args.output_path, + logger=[swanlab_logger], ) trainer.test(model, dataloader) @@ -451,6 +493,18 @@ def train(args): init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing ) + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", @@ -458,7 +512,8 @@ def train(args): strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=[swanlab_logger], ) trainer.fit(model, dataloader)