mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
66 lines
1.8 KiB
Python
66 lines
1.8 KiB
Python
import torch
|
|
|
|
|
|
try:
|
|
import deepspeed
|
|
_HAS_DEEPSPEED = True
|
|
except ModuleNotFoundError:
|
|
_HAS_DEEPSPEED = False
|
|
|
|
|
|
def create_custom_forward(module):
    """Return a plain function that calls ``module`` with the given arguments.

    The returned callable forwards every positional and keyword argument to
    ``module`` unchanged and returns whatever ``module`` returns.
    """

    def forward_with_kwargs(*inputs, **kwargs):
        output = module(*inputs, **kwargs)
        return output

    return forward_with_kwargs
|
|
|
|
|
|
def create_custom_forward_use_reentrant(module):
    """Return a positional-only function that calls ``module``.

    Unlike ``create_custom_forward``, the returned callable accepts no keyword
    arguments; every input is forwarded to ``module`` positionally.
    """

    def forward_positional(*inputs):
        output = module(*inputs)
        return output

    return forward_positional
|
|
|
|
|
|
def judge_args_requires_grad(*args):
    """Return True if any argument is a ``torch.Tensor`` with ``requires_grad`` set.

    Only the arguments themselves are inspected; tensors nested inside lists,
    tuples or dicts are not looked at, and non-tensor arguments are ignored.
    """
    return any(
        isinstance(candidate, torch.Tensor) and candidate.requires_grad
        for candidate in args
    )
|
|
|
|
|
|
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Call ``model(*args, **kwargs)``, optionally under gradient checkpointing.

    The strategy is chosen in this order:

    1. ``use_gradient_checkpointing`` is truthy, DeepSpeed was imported and
       ``deepspeed.checkpointing.is_configured()`` is true: run through
       ``deepspeed.checkpointing.checkpoint``. If none of the top-level
       arguments is a tensor requiring grad, the model is called directly.
    2. ``use_gradient_checkpointing_offload`` is truthy: run through
       ``torch.utils.checkpoint.checkpoint`` (non-reentrant) inside
       ``torch.autograd.graph.save_on_cpu()``.
    3. ``use_gradient_checkpointing`` is truthy: run through
       ``torch.utils.checkpoint.checkpoint`` (non-reentrant).
    4. Otherwise: call the model directly.

    Args:
        model: Callable to invoke (typically a module).
        use_gradient_checkpointing: Enables checkpointing (strategies 1 and 3).
        use_gradient_checkpointing_offload: Enables strategy 2.
        *args: Positional arguments for ``model``.
        **kwargs: Keyword arguments for ``model``; delivered by name on every
            path, including the DeepSpeed one.

    Returns:
        Whatever ``model`` returns.
    """
    if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
        # The DeepSpeed checkpoint call is given one flat positional list, so
        # remember the keyword names and rebuild the mapping inside the
        # checkpointed function. Passing the values positionally to the model
        # would bind them by order rather than by name.
        keyword_names = tuple(kwargs)
        num_positional = len(args)
        all_args = args + tuple(kwargs.values())
        if not judge_args_requires_grad(*all_args):
            # get the first grad_enabled tensor from un_checkpointed forward
            return model(*args, **kwargs)

        def positional_forward(*inputs):
            restored_kwargs = dict(zip(keyword_names, inputs[num_positional:]))
            return model(*inputs[:num_positional], **restored_kwargs)

        return deepspeed.checkpointing.checkpoint(positional_forward, *all_args)

    if use_gradient_checkpointing_offload:
        # Checkpoint while the save_on_cpu context is active.
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
|