This commit is contained in:
33
finetune/lora/v6/fla/utils.py
vendored
Normal file
33
finetune/lora/v6/fla/utils.py
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import functools
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def contiguous(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(ctx, *args, **kwargs):
|
||||
return fn(ctx,
|
||||
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
|
||||
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_version(version, hint):
|
||||
def decorator(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(ctx, *args, **kwargs):
|
||||
from transformers.utils.versions import require_version
|
||||
require_version(version, hint)
|
||||
return fn(ctx,
|
||||
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
|
||||
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def checkpoint(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user