mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:48:12 +00:00)
DiffSynth-Studio 2.0 major update
diffsynth/core/gradient/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward
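
With this re-export in place, callers can pull the helper from the package rather than the module file; a minimal sketch (the import path simply follows from the file layout added in this commit):

from diffsynth.core.gradient import gradient_checkpoint_forward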

diffsynth/core/gradient/gradient_checkpoint.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import torch


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
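
For context, a hedged usage sketch (not part of the commit): gradient_checkpoint_forward wraps a submodule call so that, when checkpointing is enabled, the block's activations are recomputed during backward instead of stored, and, when offload is also enabled, saved tensors are parked on CPU via save_on_cpu. The TinyBlock and TinyModel names below are hypothetical stand-ins for the repository's real transformer blocks.

import torch
import torch.nn as nn

from diffsynth.core.gradient import gradient_checkpoint_forward


class TinyBlock(nn.Module):
    # Hypothetical stand-in for a real transformer block.
    def __init__(self, dim=64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


class TinyModel(nn.Module):
    def __init__(self, dim=64, num_blocks=4):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock(dim) for _ in range(num_blocks)])

    def forward(self, x, use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False):
        for block in self.blocks:
            # With both flags False this reduces to a plain block(x) call;
            # otherwise activations inside the block are recomputed in backward.
            x = gradient_checkpoint_forward(
                block,
                use_gradient_checkpointing,
                use_gradient_checkpointing_offload,
                x,
            )
        return x


model = TinyModel()
x = torch.randn(2, 64, requires_grad=True)
out = model(x)
out.sum().backward()  # recomputation (and optional CPU offload) happens here

Two notes on the wrapper itself: use_reentrant=False is what lets keyword arguments be forwarded through torch.utils.checkpoint.checkpoint, and torch.autograd.graph.save_on_cpu trades GPU memory for host-device transfer time on the tensors that checkpointing still needs to save.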