diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py
index e2cab59..d55d792 100644
--- a/diffsynth/trainers/text_to_image.py
+++ b/diffsynth/trainers/text_to_image.py
@@ -250,6 +250,18 @@ def add_general_parsers(parser):
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
+
     return parser
 
 
@@ -270,6 +282,22 @@ def launch_training_task(model, args):
         num_workers=args.dataloader_num_workers
     )
 
+    # set swanlab logger
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="diffsynth_studio",
+            name="diffsynth_studio",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = []
+
     # train
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
@@ -279,7 +307,8 @@ def launch_training_task(model, args):
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model=model, train_dataloaders=train_loader)
 
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index 817fd5c..4e7f8bc 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -405,6 +405,17 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     args = parser.parse_args()
     return args
 
@@ -462,6 +473,20 @@ def train(args):
         init_lora_weights=args.init_lora_weights,
         use_gradient_checkpointing=args.use_gradient_checkpointing
     )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = []
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -469,7 +494,8 @@ def train(args):
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model, dataloader)
 