[feature]:Add adaptation of all models to zero3

This commit is contained in:
feng0w0
2026-02-03 15:44:53 +08:00
parent 2070bbd925
commit ca9b5e64ea
4 changed files with 9 additions and 14 deletions

View File

@@ -21,7 +21,6 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
@@ -29,7 +28,6 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
else:
model_output = model(*args, **kwargs)