diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index 7b2690f..00a352e 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -250,18 +250,6 @@ def add_general_parsers(parser): default=None, help="Pretrained LoRA path. Required if the training is resumed.", ) - parser.add_argument( - "--use_swanlab", - default=False, - action="store_true", - help="Whether to use SwanLab logger.", - ) - parser.add_argument( - "--swanlab_mode", - default=None, - help="SwanLab mode (cloud or local).", - ) - return parser @@ -281,23 +269,6 @@ def launch_training_task(model, args): batch_size=args.batch_size, num_workers=args.dataloader_num_workers ) - - # set swanlab logger - if args.use_swanlab: - from swanlab.integration.pytorch_lightning import SwanLabLogger - swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} - swanlab_config.update(vars(args)) - swanlab_logger = SwanLabLogger( - project="diffsynth_studio", - name="diffsynth_studio", - config=swanlab_config, - mode=args.swanlab_mode, - logdir=args.output_path, - ) - logger = [swanlab_logger] - else: - logger = [] - # train trainer = pl.Trainer( max_epochs=args.max_epochs, @@ -308,7 +279,6 @@ def launch_training_task(model, args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=logger, ) trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 4e7f8bc..39aa4c0 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -405,17 +405,6 @@ def parse_args(): choices=["lora", "full"], help="Model structure to train. LoRA training or full training.", ) - parser.add_argument( - "--use_swanlab", - default=False, - action="store_true", - help="Whether to use SwanLab logger.", - ) - parser.add_argument( - "--swanlab_mode", - default=None, - help="SwanLab mode (cloud or local).", - ) args = parser.parse_args() return args @@ -473,20 +462,6 @@ def train(args): init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing ) - if args.use_swanlab: - from swanlab.integration.pytorch_lightning import SwanLabLogger - swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} - swanlab_config.update(vars(args)) - swanlab_logger = SwanLabLogger( - project="wan", - name="wan", - config=swanlab_config, - mode=args.swanlab_mode, - logdir=args.output_path, - ) - logger = [swanlab_logger] - else: - logger = [] trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", @@ -495,7 +470,6 @@ def train(args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=logger, ) trainer.fit(model, dataloader)