From b5c1d33e5807e915e2aac501e91f18f8e318d006 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 27 Feb 2025 19:21:51 +0800 Subject: [PATCH] update swanlab log --- diffsynth/trainers/text_to_image.py | 30 +++++++---------------------- examples/wanvideo/train_wan_t2v.py | 30 +++++++---------------------- 2 files changed, 14 insertions(+), 46 deletions(-) diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index d811a40..d55d792 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -256,29 +256,11 @@ def add_general_parsers(parser): action="store_true", help="Whether to use SwanLab logger.", ) - parser.add_argument( - "--swanlab_project", - type=str, - default="diffsynth_studio", - help="SwanLab project name.", - ) - parser.add_argument( - "--swanlab_name", - type=str, - default="diffsynth_studio_train", - help="SwanLab experimentname.", - ) parser.add_argument( "--swanlab_mode", default=None, help="SwanLab mode (cloud or local).", ) - parser.add_argument( - "--swanlab_logdir", - type=str, - default=None, - help="SwanLab local log directory.", - ) return parser @@ -301,18 +283,20 @@ def launch_training_task(model, args): ) # set swanlab logger - swanlab_logger = None if args.use_swanlab: from swanlab.integration.pytorch_lightning import SwanLabLogger swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config.update(vars(args)) swanlab_logger = SwanLabLogger( - project=args.swanlab_project, - name=args.swanlab_name, + project="diffsynth_studio", + name="diffsynth_studio", config=swanlab_config, mode=args.swanlab_mode, - logdir=args.swanlab_logdir, + logdir=args.output_path, ) + logger = [swanlab_config] + else: + logger = [] # train trainer = pl.Trainer( @@ -324,7 +308,7 @@ def launch_training_task(model, args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=[swanlab_logger], + logger=logger, ) trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index c26db7a..47ee8ef 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -400,29 +400,11 @@ def parse_args(): action="store_true", help="Whether to use SwanLab logger.", ) - parser.add_argument( - "--swanlab_project", - type=str, - default="wan_t2v", - help="SwanLab project name.", - ) - parser.add_argument( - "--swanlab_name", - type=str, - default="wan_t2v_train", - help="SwanLab experimentname.", - ) parser.add_argument( "--swanlab_mode", default=None, help="SwanLab mode (cloud or local).", ) - parser.add_argument( - "--swanlab_logdir", - type=str, - default=None, - help="SwanLab local log directory.", - ) args = parser.parse_args() return args @@ -493,18 +475,20 @@ def train(args): init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing ) - swanlab_logger = None if args.use_swanlab: from swanlab.integration.pytorch_lightning import SwanLabLogger swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config.update(vars(args)) swanlab_logger = SwanLabLogger( - project=args.swanlab_project, - name=args.swanlab_name, + project="wan", + name="wan", config=swanlab_config, mode=args.swanlab_mode, - logdir=args.swanlab_logdir, + logdir=args.output_path, ) + logger = [swanlab_logger] + else: + logger = [] trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", @@ -513,7 +497,7 @@ def train(args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=[swanlab_logger], + logger=logger, ) trainer.fit(model, dataloader)