# Mirror of https://github.com/modelscope/DiffSynth-Studio.git
# Synced 2026-03-19 14:58:12 +00:00
# 37 lines, 985 B, Python
import torch
|
|
|
|
|
|
def create_custom_forward(module):
    """Wrap *module* in a plain function that forwards every argument to it.

    Args:
        module: Any callable (used below as the function handed to
            ``torch.utils.checkpoint.checkpoint``).

    Returns:
        A function accepting arbitrary positional and keyword arguments that
        returns whatever ``module`` returns for those arguments.
    """

    def custom_forward(*inputs, **kwargs):
        # Pure pass-through: no argument is inspected or modified.
        result = module(*inputs, **kwargs)
        return result

    return custom_forward
|
|
|
|
|
|
def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    """Call *model*, optionally through ``torch.utils.checkpoint.checkpoint``.

    Args:
        model: Callable to run with ``*args`` / ``**kwargs``.
        use_gradient_checkpointing: When truthy (and offload is not
            requested), run ``model`` through the checkpoint utility.
        use_gradient_checkpointing_offload: When truthy, run ``model``
            through the checkpoint utility inside a
            ``torch.autograd.graph.save_on_cpu()`` context. This flag is
            checked first and takes precedence over
            ``use_gradient_checkpointing``.
        *args: Positional arguments forwarded to ``model``.
        **kwargs: Keyword arguments forwarded to ``model``.

    Returns:
        Whatever ``model`` returns (directly, or via the checkpoint call).
    """

    def run_checkpointed():
        # Shared by both checkpointing modes; always non-reentrant with the
        # determinism check turned off.
        return torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
            determinism_check="none",
        )

    # Offload mode is tested first so it wins when both flags are set.
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            return run_checkpointed()

    if use_gradient_checkpointing:
        return run_checkpointed()

    # Neither flag set: ordinary call with no checkpointing involved.
    return model(*args, **kwargs)
|