diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py
index 3177191..8e94aee 100644
--- a/diffsynth/trainers/text_to_image.py
+++ b/diffsynth/trainers/text_to_image.py
@@ -250,6 +250,17 @@ def add_general_parsers(parser):
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     return parser
 
 
@@ -270,6 +281,20 @@ def launch_training_task(model, args):
         num_workers=args.dataloader_num_workers
     )
     # train
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="diffsynth_studio",
+            name="diffsynth_studio",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -278,7 +303,8 @@
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model=model, train_dataloaders=train_loader)
 
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index 4972c26..51ceb3f 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -132,8 +132,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --steps_per_epoch 500 \
   --max_epochs 10 \
   --learning_rate 1e-4 \
-  --lora_rank 4 \
-  --lora_alpha 4 \
+  --lora_rank 16 \
+  --lora_alpha 16 \
   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
   --accumulate_grad_batches 1 \
   --use_gradient_checkpointing
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index e16aaa0..45fbe90 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -423,6 +423,17 @@ def parse_args():
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     args = parser.parse_args()
     return args
 
@@ -481,6 +492,20 @@ def train(args):
         use_gradient_checkpointing=args.use_gradient_checkpointing,
         pretrained_lora_path=args.pretrained_lora_path,
     )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -489,6 +514,7 @@
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
         callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model, dataloader)
 