del swanlab

This commit is contained in:
Artiprocher
2025-02-28 16:13:06 +08:00
parent 6fa8dbe077
commit a466fdca8f
2 changed files with 0 additions and 55 deletions

View File

@@ -405,17 +405,6 @@ def parse_args():
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
parser.add_argument(
"--use_swanlab",
default=False,
action="store_true",
help="Whether to use SwanLab logger.",
)
parser.add_argument(
"--swanlab_mode",
default=None,
help="SwanLab mode (cloud or local).",
)
args = parser.parse_args()
return args
@@ -473,20 +462,6 @@ def train(args):
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing
)
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
)
logger = [swanlab_logger]
else:
logger = []
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
@@ -495,7 +470,6 @@ def train(args):
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
logger=logger,
)
trainer.fit(model, dataloader)