diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py
index 568c77a..4a82228 100644
--- a/examples/flux/model_training/train.py
+++ b/examples/flux/model_training/train.py
@@ -117,6 +117,4 @@ if __name__ == "__main__":
         remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
         state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
     )
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     launch_training_task(dataset, model, model_logger, args=args)
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index b9b7d8c..37494e7 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -126,6 +126,4 @@ if __name__ == "__main__":
         args.output_path,
         remove_prefix_in_ckpt=args.remove_prefix_in_ckpt
     )
-    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)
-    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
     launch_training_task(dataset, model, model_logger, args=args)