diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 447c16a..3ebe524 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -156,7 +156,7 @@ if __name__ == "__main__": fp8_models=args.fp8_models, offload_models=args.offload_models, task=args.task, - device=accelerator.device, + device="cpu" if args.initialize_model_on_cpu else accelerator.device, zero_cond_t=args.zero_cond_t, ) model_logger = ModelLogger(