This commit is contained in:
josc146
2024-05-28 22:35:47 +08:00
parent 3488d22d22
commit f05a4acb04
138 changed files with 29047 additions and 334 deletions

33
finetune/lora/v6/fla/utils.py vendored Normal file
View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
import functools
import torch
def contiguous(fn):
@functools.wraps(fn)
def wrapper(ctx, *args, **kwargs):
return fn(ctx,
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
return wrapper
def require_version(version, hint):
def decorator(fn):
@functools.wraps(fn)
def wrapper(ctx, *args, **kwargs):
from transformers.utils.versions import require_version
require_version(version, hint)
return fn(ctx,
*(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
**{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
return wrapper
return decorator
def checkpoint(func):
def wrapper(*args, **kwargs):
return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
return wrapper