mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
update swanlab log
This commit is contained in:
@@ -256,29 +256,11 @@ def add_general_parsers(parser):
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use SwanLab logger.",
|
help="Whether to use SwanLab logger.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_project",
|
|
||||||
type=str,
|
|
||||||
default="diffsynth_studio",
|
|
||||||
help="SwanLab project name.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_name",
|
|
||||||
type=str,
|
|
||||||
default="diffsynth_studio_train",
|
|
||||||
help="SwanLab experimentname.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--swanlab_mode",
|
"--swanlab_mode",
|
||||||
default=None,
|
default=None,
|
||||||
help="SwanLab mode (cloud or local).",
|
help="SwanLab mode (cloud or local).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_logdir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="SwanLab local log directory.",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -301,18 +283,20 @@ def launch_training_task(model, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# set swanlab logger
|
# set swanlab logger
|
||||||
swanlab_logger = None
|
|
||||||
if args.use_swanlab:
|
if args.use_swanlab:
|
||||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||||
swanlab_config.update(vars(args))
|
swanlab_config.update(vars(args))
|
||||||
swanlab_logger = SwanLabLogger(
|
swanlab_logger = SwanLabLogger(
|
||||||
project=args.swanlab_project,
|
project="diffsynth_studio",
|
||||||
name=args.swanlab_name,
|
name="diffsynth_studio",
|
||||||
config=swanlab_config,
|
config=swanlab_config,
|
||||||
mode=args.swanlab_mode,
|
mode=args.swanlab_mode,
|
||||||
logdir=args.swanlab_logdir,
|
logdir=args.output_path,
|
||||||
)
|
)
|
||||||
|
logger = [swanlab_config]
|
||||||
|
else:
|
||||||
|
logger = []
|
||||||
|
|
||||||
# train
|
# train
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
@@ -324,7 +308,7 @@ def launch_training_task(model, args):
|
|||||||
default_root_dir=args.output_path,
|
default_root_dir=args.output_path,
|
||||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||||
logger=[swanlab_logger],
|
logger=logger,
|
||||||
)
|
)
|
||||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||||
|
|
||||||
|
|||||||
@@ -400,29 +400,11 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Whether to use SwanLab logger.",
|
help="Whether to use SwanLab logger.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_project",
|
|
||||||
type=str,
|
|
||||||
default="wan_t2v",
|
|
||||||
help="SwanLab project name.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_name",
|
|
||||||
type=str,
|
|
||||||
default="wan_t2v_train",
|
|
||||||
help="SwanLab experimentname.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--swanlab_mode",
|
"--swanlab_mode",
|
||||||
default=None,
|
default=None,
|
||||||
help="SwanLab mode (cloud or local).",
|
help="SwanLab mode (cloud or local).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--swanlab_logdir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="SwanLab local log directory.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -493,18 +475,20 @@ def train(args):
|
|||||||
init_lora_weights=args.init_lora_weights,
|
init_lora_weights=args.init_lora_weights,
|
||||||
use_gradient_checkpointing=args.use_gradient_checkpointing
|
use_gradient_checkpointing=args.use_gradient_checkpointing
|
||||||
)
|
)
|
||||||
swanlab_logger = None
|
|
||||||
if args.use_swanlab:
|
if args.use_swanlab:
|
||||||
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
from swanlab.integration.pytorch_lightning import SwanLabLogger
|
||||||
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
|
||||||
swanlab_config.update(vars(args))
|
swanlab_config.update(vars(args))
|
||||||
swanlab_logger = SwanLabLogger(
|
swanlab_logger = SwanLabLogger(
|
||||||
project=args.swanlab_project,
|
project="wan",
|
||||||
name=args.swanlab_name,
|
name="wan",
|
||||||
config=swanlab_config,
|
config=swanlab_config,
|
||||||
mode=args.swanlab_mode,
|
mode=args.swanlab_mode,
|
||||||
logdir=args.swanlab_logdir,
|
logdir=args.output_path,
|
||||||
)
|
)
|
||||||
|
logger = [swanlab_logger]
|
||||||
|
else:
|
||||||
|
logger = []
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
@@ -513,7 +497,7 @@ def train(args):
|
|||||||
default_root_dir=args.output_path,
|
default_root_dir=args.output_path,
|
||||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
|
||||||
logger=[swanlab_logger],
|
logger=logger,
|
||||||
)
|
)
|
||||||
trainer.fit(model, dataloader)
|
trainer.fit(model, dataloader)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user