diff --git a/examples/hunyuan_dit/train_hunyuan_dit_lora.py b/examples/hunyuan_dit/train_hunyuan_dit_lora.py index 3dc5a69..118b8fe 100644 --- a/examples/hunyuan_dit/train_hunyuan_dit_lora.py +++ b/examples/hunyuan_dit/train_hunyuan_dit_lora.py @@ -256,7 +256,7 @@ if __name__ == '__main__': # dataset and data loader dataset = TextImageDataset( args.dataset_path, - steps_per_epoch=args.steps_per_epoch, + steps_per_epoch=args.steps_per_epoch * args.batch_size, height=args.height, width=args.width, center_crop=args.center_crop,