[feature]:Add adaptation of all models to zero3

This commit is contained in:
feng0w0
2026-01-27 11:24:43 +08:00
parent ffb7a138f7
commit 4e9db263b0
15 changed files with 266 additions and 34 deletions

View File

@@ -21,6 +21,7 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
elif use_gradient_checkpointing:
model_output = torch.utils.checkpoint.checkpoint(
@@ -28,6 +29,7 @@ def gradient_checkpoint_forward(
*args,
**kwargs,
use_reentrant=False,
determinism_check="none"
)
else:
model_output = model(*args, **kwargs)