RWKV-Runner/finetune/lora/v6/fla/utils.py

34 lines
1.0 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
import functools
import torch
def contiguous(fn):
    """Decorate ``fn`` so that its tensor arguments are made contiguous.

    The decorated callable keeps the ``fn(ctx, *args, **kwargs)`` calling
    convention. ``ctx`` is forwarded untouched (presumably an autograd
    context object -- TODO confirm against callers). Every positional or
    keyword argument that is a ``torch.Tensor`` is replaced by
    ``value.contiguous()``; every other value is passed through unchanged.
    """
    def _as_contiguous(value):
        # Only tensors are converted; scalars, None, etc. are left as-is.
        if isinstance(value, torch.Tensor):
            return value.contiguous()
        return value

    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        # Positional arguments are converted before keyword arguments,
        # the same order in which the original call expression evaluated them.
        positional = [_as_contiguous(arg) for arg in args]
        keywords = {name: _as_contiguous(val) for name, val in kwargs.items()}
        return fn(ctx, *positional, **keywords)

    return wrapper
def require_version(version, hint):
    """Build a decorator that checks a package requirement before calling ``fn``.

    The check is delegated to ``transformers.utils.versions.require_version``
    with ``version`` and ``hint`` forwarded unchanged. It is performed lazily,
    on the first call of the decorated function rather than at decoration
    time, so merely decorating a function does not require ``transformers``
    to be installed. After the check has succeeded once, it is skipped on
    later calls, because the underlying check inspects installed-package
    metadata, which is wasted work when repeated on every call. If the check
    raises, nothing is cached and the next call checks again.

    The decorated callable keeps the ``fn(ctx, *args, **kwargs)`` calling
    convention. Every positional or keyword argument that is a
    ``torch.Tensor`` is replaced by ``value.contiguous()``; other values are
    passed through unchanged.

    Args:
        version: Requirement specifier forwarded to transformers'
            ``require_version``.
        hint: Message forwarded to transformers' ``require_version``,
            shown when the requirement is not met.

    Returns:
        A decorator that wraps ``fn`` as described above.

    Raises:
        ImportError: From the deferred import if ``transformers`` is not
            installed. Any exception raised by transformers' check also
            propagates from the decorated call.
    """
    def decorator(fn):
        # Per-decorated-function flag: set only after a successful check.
        checked = False

        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            nonlocal checked
            if not checked:
                # Deferred import, aliased so it does not shadow this
                # module-level ``require_version`` factory.
                from transformers.utils.versions import require_version as _check_requirement
                _check_requirement(version, hint)
                checked = True
            return fn(ctx,
                      *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
                      **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
        return wrapper
    return decorator
def checkpoint(func):
    """Decorate ``func`` so every call goes through PyTorch activation checkpointing.

    Calling the decorated function as ``wrapper(*args, **kwargs)`` evaluates
    ``torch.utils.checkpoint.checkpoint(func, *args, **kwargs)``. The keyword
    arguments go to ``torch.utils.checkpoint.checkpoint`` itself. It consumes
    its own options (for example ``use_reentrant``); whether any remaining
    keywords reach ``func`` depends on the installed PyTorch version
    (NOTE(review): confirm for the versions this project supports).

    ``functools.wraps`` is applied so the result keeps ``func``'s
    ``__name__``/``__doc__`` and exposes ``__wrapped__``, matching the other
    decorators in this module.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Importing a package does not by itself guarantee that its submodules
        # are loaded, so load ``torch.utils.checkpoint`` explicitly before the
        # attribute lookup below. This is a cheap sys.modules hit after the
        # first call.
        import torch.utils.checkpoint
        return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
    return wrapper