diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py
index e2a60a7..568c77a 100644
--- a/examples/flux/model_training/train.py
+++ b/examples/flux/model_training/train.py
@@ -119,11 +119,4 @@ if __name__ == "__main__":
     )
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    launch_training_task(
-        dataset, model, model_logger, optimizer, scheduler,
-        num_epochs=args.num_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        save_steps=args.save_steps,
-        find_unused_parameters=args.find_unused_parameters,
-        num_workers=args.dataset_num_workers,
-    )
+    launch_training_task(dataset, model, model_logger, args=args)
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index b811f67..b9b7d8c 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -128,11 +128,4 @@ if __name__ == "__main__":
     )
     optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
     scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
-    launch_training_task(
-        dataset, model, model_logger, optimizer, scheduler,
-        num_epochs=args.num_epochs,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        save_steps=args.save_steps,
-        find_unused_parameters=args.find_unused_parameters,
-        num_workers=args.dataset_num_workers,
-    )
+    launch_training_task(dataset, model, model_logger, args=args)
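
The new call sites hand the whole parsed argparse namespace to launch_training_task instead of forwarding each hyperparameter by keyword. The diff does not show the receiving side of that change; the sketch below is only an assumed illustration of how launch_training_task could unpack the same fields the removed call sites passed (num_epochs, gradient_accumulation_steps, save_steps, find_unused_parameters, dataset_num_workers). The signature, defaults, and kwargs fallback are assumptions, not the project's actual implementation.

# A minimal sketch, assuming launch_training_task reads its training options
# from the argparse.Namespace passed as ``args``; names and defaults below are
# illustrative, not the library's real code.
import argparse
from typing import Optional


def launch_training_task(dataset, model, model_logger,
                         optimizer=None, scheduler=None,
                         args: Optional[argparse.Namespace] = None, **kwargs):
    # Prefer the namespace when provided; fall back to explicit keyword
    # arguments so the old call style would keep working (assumption).
    cfg = vars(args) if args is not None else kwargs
    num_epochs = cfg.get("num_epochs", 1)
    gradient_accumulation_steps = cfg.get("gradient_accumulation_steps", 1)
    save_steps = cfg.get("save_steps")
    find_unused_parameters = cfg.get("find_unused_parameters", False)
    # The old call sites mapped args.dataset_num_workers to num_workers.
    num_workers = cfg.get("num_workers", cfg.get("dataset_num_workers", 0))
    print(f"epochs={num_epochs}, grad_accum={gradient_accumulation_steps}, "
          f"save_steps={save_steps}, find_unused={find_unused_parameters}, "
          f"workers={num_workers}")


if __name__ == "__main__":
    # New call style from the diff: the parsed args carry the hyperparameters.
    demo_args = argparse.Namespace(num_epochs=2, gradient_accumulation_steps=4,
                                   save_steps=500, find_unused_parameters=True,
                                   dataset_num_workers=8)
    launch_training_task(dataset=None, model=None, model_logger=None, args=demo_args)

Whatever the real unpacking looks like, consolidating the per-script keyword lists into a single args object keeps the flux and wanvideo train.py entry points identical at the call site, so new training options only need to be added to the argument parser rather than threaded through every example script.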