mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
del swanlab
This commit is contained in:
@@ -250,18 +250,6 @@ def add_general_parsers(parser):
|
||||
default=None,
|
||||
help="Pretrained LoRA path. Required if the training is resumed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_swanlab",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use SwanLab logger.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swanlab_mode",
|
||||
default=None,
|
||||
help="SwanLab mode (cloud or local).",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -282,22 +270,6 @@ def launch_training_task(model, args):
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# set swanlab logger
|
||||
if args.use_swanlab:
|
||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||
swanlab_config.update(vars(args))
|
||||
swanlab_logger = SwanLabLogger(
|
||||
project="diffsynth_studio",
|
||||
name="diffsynth_studio",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=args.output_path,
|
||||
)
|
||||
logger = [swanlab_config]
|
||||
else:
|
||||
logger = []
|
||||
|
||||
# train
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
@@ -308,7 +280,6 @@ def launch_training_task(model, args):
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
|
||||
|
||||
@@ -405,17 +405,6 @@ def parse_args():
|
||||
choices=["lora", "full"],
|
||||
help="Model structure to train. LoRA training or full training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_swanlab",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use SwanLab logger.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--swanlab_mode",
|
||||
default=None,
|
||||
help="SwanLab mode (cloud or local).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@@ -473,20 +462,6 @@ def train(args):
|
||||
init_lora_weights=args.init_lora_weights,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing
|
||||
)
|
||||
if args.use_swanlab:
|
||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||
swanlab_config.update(vars(args))
|
||||
swanlab_logger = SwanLabLogger(
|
||||
project="wan",
|
||||
name="wan",
|
||||
config=swanlab_config,
|
||||
mode=args.swanlab_mode,
|
||||
logdir=args.output_path,
|
||||
)
|
||||
logger = [swanlab_logger]
|
||||
else:
|
||||
logger = []
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
@@ -495,7 +470,6 @@ def train(args):
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||
logger=logger,
|
||||
)
|
||||
trainer.fit(model, dataloader)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user