From 1419bec53dd8d0bb5eb4ed3463a535e051646d90 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Wed, 26 Feb 2025 17:12:54 +0800 Subject: [PATCH 1/3] feat: add swanlab logger --- diffsynth/trainers/text_to_image.py | 47 +++++++++++++++++++++++- examples/wanvideo/train_wan_t2v.py | 57 ++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index e2cab59..d811a40 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -250,6 +250,36 @@ def add_general_parsers(parser): default=None, help="Pretrained LoRA path. Required if the training is resumed.", ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_project", + type=str, + default="diffsynth_studio", + help="SwanLab project name.", + ) + parser.add_argument( + "--swanlab_name", + type=str, + default="diffsynth_studio_train", + help="SwanLab experimentname.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--swanlab_logdir", + type=str, + default=None, + help="SwanLab local log directory.", + ) + return parser @@ -270,6 +300,20 @@ def launch_training_task(model, args): num_workers=args.dataloader_num_workers ) + # set swanlab logger + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) + # train trainer = pl.Trainer( max_epochs=args.max_epochs, @@ -279,7 +323,8 @@ def launch_training_task(model, args): strategy=args.training_strategy, 
default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=[swanlab_logger], ) trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 53de964..c26db7a 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -394,6 +394,35 @@ def parse_args(): choices=["lora", "full"], help="Model structure to train. LoRA training or full training.", ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_project", + type=str, + default="wan_t2v", + help="SwanLab project name.", + ) + parser.add_argument( + "--swanlab_name", + type=str, + default="wan_t2v_train", + help="SwanLab experimentname.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + parser.add_argument( + "--swanlab_logdir", + type=str, + default=None, + help="SwanLab local log directory.", + ) args = parser.parse_args() return args @@ -421,10 +450,23 @@ def data_process(args): tile_size=(args.tile_size_height, args.tile_size_width), tile_stride=(args.tile_stride_height, args.tile_stride_width), ) + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) trainer = pl.Trainer( accelerator="gpu", devices="auto", default_root_dir=args.output_path, + logger=[swanlab_logger], ) trainer.test(model, dataloader) @@ -451,6 +493,18 @@ def 
train(args): init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing ) + swanlab_logger = None + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project=args.swanlab_project, + name=args.swanlab_name, + config=swanlab_config, + mode=args.swanlab_mode, + logdir=args.swanlab_logdir, + ) trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", @@ -458,7 +512,8 @@ def train(args): strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=[swanlab_logger], ) trainer.fit(model, dataloader) From b5c1d33e5807e915e2aac501e91f18f8e318d006 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 27 Feb 2025 19:21:51 +0800 Subject: [PATCH 2/3] update swanlab log --- diffsynth/trainers/text_to_image.py | 30 +++++++---------------------- examples/wanvideo/train_wan_t2v.py | 30 +++++++---------------------- 2 files changed, 14 insertions(+), 46 deletions(-) diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py index d811a40..d55d792 100644 --- a/diffsynth/trainers/text_to_image.py +++ b/diffsynth/trainers/text_to_image.py @@ -256,29 +256,11 @@ def add_general_parsers(parser): action="store_true", help="Whether to use SwanLab logger.", ) - parser.add_argument( - "--swanlab_project", - type=str, - default="diffsynth_studio", - help="SwanLab project name.", - ) - parser.add_argument( - "--swanlab_name", - type=str, - default="diffsynth_studio_train", - help="SwanLab experimentname.", - ) parser.add_argument( "--swanlab_mode", default=None, help="SwanLab mode (cloud or local).", ) - parser.add_argument( - "--swanlab_logdir", - 
type=str, - default=None, - help="SwanLab local log directory.", - ) return parser @@ -301,18 +283,20 @@ def launch_training_task(model, args): ) # set swanlab logger - swanlab_logger = None if args.use_swanlab: from swanlab.integration.pytorch_lightning import SwanLabLogger swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config.update(vars(args)) swanlab_logger = SwanLabLogger( - project=args.swanlab_project, - name=args.swanlab_name, + project="diffsynth_studio", + name="diffsynth_studio", config=swanlab_config, mode=args.swanlab_mode, - logdir=args.swanlab_logdir, + logdir=args.output_path, ) + logger = [swanlab_logger] + else: + logger = [] # train trainer = pl.Trainer( @@ -324,7 +308,7 @@ def launch_training_task(model, args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=[swanlab_logger], + logger=logger, ) trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index c26db7a..47ee8ef 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -400,29 +400,11 @@ def parse_args(): action="store_true", help="Whether to use SwanLab logger.", ) - parser.add_argument( - "--swanlab_project", - type=str, - default="wan_t2v", - help="SwanLab project name.", - ) - parser.add_argument( - "--swanlab_name", - type=str, - default="wan_t2v_train", - help="SwanLab experimentname.", - ) parser.add_argument( "--swanlab_mode", default=None, help="SwanLab mode (cloud or local).", ) - parser.add_argument( - "--swanlab_logdir", - type=str, - default=None, - help="SwanLab local log directory.", - ) args = parser.parse_args() return args @@ -493,18 +475,20 @@ def train(args): init_lora_weights=args.init_lora_weights, use_gradient_checkpointing=args.use_gradient_checkpointing ) - swanlab_logger = None if args.use_swanlab: from
swanlab.integration.pytorch_lightning import SwanLabLogger swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} swanlab_config.update(vars(args)) swanlab_logger = SwanLabLogger( - project=args.swanlab_project, - name=args.swanlab_name, + project="wan", + name="wan", config=swanlab_config, mode=args.swanlab_mode, - logdir=args.swanlab_logdir, + logdir=args.output_path, ) + logger = [swanlab_logger] + else: + logger = [] trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", @@ -513,7 +497,7 @@ def train(args): default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=[swanlab_logger], + logger=logger, ) trainer.fit(model, dataloader) From a57749ef6073f56563e88b0df88d7b4b19de5de8 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 27 Feb 2025 19:30:53 +0800 Subject: [PATCH 3/3] update swanlab log --- examples/wanvideo/train_wan_t2v.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 47ee8ef..7eb0afa 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -432,23 +432,10 @@ def data_process(args): tile_size=(args.tile_size_height, args.tile_size_width), tile_stride=(args.tile_stride_height, args.tile_stride_width), ) - swanlab_logger = None - if args.use_swanlab: - from swanlab.integration.pytorch_lightning import SwanLabLogger - swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} - swanlab_config.update(vars(args)) - swanlab_logger = SwanLabLogger( - project=args.swanlab_project, - name=args.swanlab_name, - config=swanlab_config, - mode=args.swanlab_mode, - logdir=args.swanlab_logdir, - ) trainer = pl.Trainer( accelerator="gpu", devices="auto", default_root_dir=args.output_path, - logger=[swanlab_logger], ) trainer.test(model, dataloader)