34 lines
1.0 KiB
Python
34 lines
1.0 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import functools
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def contiguous(fn):
    """Decorate ``fn`` so its tensor arguments arrive in contiguous memory.

    The first argument (``ctx``) is forwarded untouched. Every other
    positional or keyword argument that is a ``torch.Tensor`` is replaced by
    ``tensor.contiguous()``; anything that is not a tensor is forwarded as-is.
    """

    def _make_contiguous(value):
        # Only tensors are converted; all other values pass straight through.
        return value.contiguous() if isinstance(value, torch.Tensor) else value

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        # Positional args are converted before keyword args, matching the
        # evaluation order of a call with `*args, **kwargs` unpacking.
        new_args = tuple(map(_make_contiguous, args))
        new_kwargs = {name: _make_contiguous(val) for name, val in kwargs.items()}
        return fn(ctx, *new_args, **new_kwargs)

    return wrapper
|
||
|
|
||
|
|
||
|
def require_version(version, hint):
    """Build a decorator that enforces a package requirement before calling ``fn``.

    On the first call of the decorated function, ``version`` and ``hint`` are
    forwarded to ``transformers.utils.versions.require_version``. Once that
    check has passed, later calls skip it. A failed check raises and is
    retried on the next call. The decorated function also receives every
    ``torch.Tensor`` argument (positional or keyword, but not ``ctx``) as a
    contiguous tensor.

    Args:
        version: Requirement string forwarded to transformers'
            ``require_version`` check.
        hint: Extra message forwarded to the same check, shown when the
            requirement is not met.

    Returns:
        A decorator for callables with signature ``fn(ctx, *args, **kwargs)``.

    Raises:
        ImportError: from the decorated call if ``transformers`` is not
            installed. Any exception raised by the version check also
            propagates from the decorated call.
    """

    def _make_contiguous(value):
        # Only tensors are converted; all other values pass straight through.
        return value.contiguous() if isinstance(value, torch.Tensor) else value

    def decorator(fn):
        # Flipped to True after the first successful check so the (metadata
        # lookup based) requirement check is not repeated on every call.
        checked = False

        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            nonlocal checked
            if not checked:
                # Imported lazily so `transformers` is only needed when a
                # decorated function is actually called; aliased to avoid
                # shadowing this module's own `require_version`.
                from transformers.utils.versions import require_version as _check_version

                _check_version(version, hint)
                checked = True
            return fn(
                ctx,
                *(_make_contiguous(a) for a in args),
                **{k: _make_contiguous(v) for k, v in kwargs.items()},
            )

        return wrapper

    return decorator
|
||
|
|
||
|
|
||
|
def checkpoint(func):
    """Wrap ``func`` so every call runs under ``torch.utils.checkpoint.checkpoint``.

    The wrapper forwards all positional and keyword arguments to
    ``torch.utils.checkpoint.checkpoint(func, *args, **kwargs)``. Keyword
    options understood by that function (e.g. ``use_reentrant``) are
    therefore consumed there rather than passed to ``func``.

    Args:
        func: Callable whose intermediate activations should be recomputed
            during backward instead of stored.

    Returns:
        A wrapper that keeps ``func``'s name and docstring (via
        ``functools.wraps``).
    """
    # Submodules are only bound as attributes once imported; a bare
    # `import torch` is not documented to load `torch.utils.checkpoint`,
    # so import it explicitly.
    from torch.utils.checkpoint import checkpoint as _torch_checkpoint

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return _torch_checkpoint(func, *args, **kwargs)

    return wrapper
|