fix swanlab after test

Artiprocher
2025-03-03 18:59:34 +08:00
parent eb4d5187d8
commit 6b67a11ad6
3 changed files with 55 additions and 3 deletions

View File

@@ -250,6 +250,17 @@ def add_general_parsers(parser):
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     return parser
@@ -270,6 +281,20 @@ def launch_training_task(model, args):
         num_workers=args.dataloader_num_workers
     )
     # train
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="diffsynth_studio",
+            name="diffsynth_studio",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -278,7 +303,8 @@ def launch_training_task(model, args):
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model=model, train_dataloaders=train_loader)
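
For reference, here is a minimal, self-contained sketch of the wiring this commit adds, assembled from the hunks above; it is not part of the commit. The package layout (`import lightning as pl`) is inferred from the diff's `pl.pytorch.callbacks` usage, and the `--output_path` default plus the `__main__` body are illustrative assumptions.

# Minimal sketch of the new logger wiring (illustrative, not part of the
# commit). SwanLab is imported lazily so it stays an optional dependency,
# and returning None leaves Lightning's default logging in place.
import argparse

import lightning as pl  # matches the diff's pl.pytorch.callbacks usage


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_swanlab", default=False, action="store_true",
                        help="Whether to use SwanLab logger.")
    parser.add_argument("--swanlab_mode", default=None,
                        help="SwanLab mode (cloud or local).")
    parser.add_argument("--output_path", default="./",
                        help="Log directory (assumed default).")
    return parser.parse_args()


def build_logger(args):
    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))  # record every CLI arg in the run config
        return [SwanLabLogger(
            project="diffsynth_studio",
            name="diffsynth_studio",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=args.output_path,
        )]
    return None


if __name__ == "__main__":
    args = parse_args()
    trainer = pl.Trainer(max_epochs=1, logger=build_logger(args))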

View File

@@ -132,8 +132,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
   --steps_per_epoch 500 \
   --max_epochs 10 \
   --learning_rate 1e-4 \
-  --lora_rank 4 \
-  --lora_alpha 4 \
+  --lora_rank 16 \
+  --lora_alpha 16 \
   --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
   --accumulate_grad_batches 1 \
   --use_gradient_checkpointing
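
With this commit, the same command can opt into SwanLab logging by appending the new flags, for example `--use_swanlab --swanlab_mode cloud`; the mode value here is illustrative, with `local` being the other option named in the flag's help text.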

View File

@@ -423,6 +423,17 @@ def parse_args():
         default=None,
         help="Pretrained LoRA path. Required if the training is resumed.",
     )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     args = parser.parse_args()
     return args
@@ -481,6 +492,20 @@ def train(args):
         use_gradient_checkpointing=args.use_gradient_checkpointing,
         pretrained_lora_path=args.pretrained_lora_path,
     )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=args.output_path,
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
@@ -489,6 +514,7 @@ def train(args):
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
         callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model, dataloader)
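
A note on the design as it reads from the diff: the `from swanlab.integration.pytorch_lightning import SwanLabLogger` import sits inside the `if args.use_swanlab:` branch, so SwanLab stays an optional dependency for users who never pass the flag, and runs without it hand `logger=None` to the Trainer, keeping Lightning's default logging behavior. `swanlab_config.update(vars(args))` records the full CLI configuration alongside each run.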