# -*- coding: utf-8 -*-

import functools

import torch
# Imported explicitly: `import torch` alone does not guarantee that the
# `torch.utils.checkpoint` submodule is loaded, and `checkpoint` below uses it.
import torch.utils.checkpoint


def _to_contiguous(x):
    """Return ``x.contiguous()`` if *x* is a ``torch.Tensor``, else *x* unchanged."""
    return x.contiguous() if isinstance(x, torch.Tensor) else x


def contiguous(fn):
    """Decorate ``fn(ctx, *args, **kwargs)`` so its tensor arguments are contiguous.

    Every positional and keyword argument after ``ctx`` that is a
    ``torch.Tensor`` is replaced by ``arg.contiguous()`` before ``fn`` is
    called. Non-tensor arguments are passed through unchanged. ``ctx`` is
    never converted; presumably it is a ``torch.autograd.Function`` context
    object (confirm against callers).

    Args:
        fn: callable whose first positional parameter is ``ctx``.

    Returns:
        A wrapper with ``fn``'s metadata (name, docstring) that returns
        whatever ``fn`` returns.
    """
    @functools.wraps(fn)
    def wrapper(ctx, *args, **kwargs):
        return fn(
            ctx,
            *(_to_contiguous(a) for a in args),
            **{k: _to_contiguous(v) for k, v in kwargs.items()},
        )
    return wrapper


def require_version(version, hint):
    """Build a decorator that checks a package requirement, then applies :func:`contiguous`.

    On call, the decorated function first runs
    ``transformers.utils.versions.require_version(version, hint)``. That check
    is expected to raise when the requirement is not met. Once the check has
    succeeded, later calls skip it, because the metadata lookup is not free and
    the decorated functions may sit on a hot path. A failing check is not
    remembered, so it raises again on every call. Tensor arguments are then
    made contiguous exactly as :func:`contiguous` does.

    Args:
        version: requirement string forwarded to transformers'
            ``require_version`` (presumably pip-style, e.g. ``"pkg>=1.0"``;
            confirm).
        hint: message forwarded to transformers' ``require_version``.

    Returns:
        A decorator for callables of the form ``fn(ctx, *args, **kwargs)``.
    """
    def decorator(fn):
        inner = contiguous(fn)
        satisfied = False

        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            nonlocal satisfied
            if not satisfied:
                # Lazy import: transformers is only needed once a decorated
                # function actually runs. The alias avoids shadowing this
                # decorator's own name.
                from transformers.utils.versions import require_version as _check_version
                _check_version(version, hint)
                satisfied = True
            return inner(ctx, *args, **kwargs)
        return wrapper
    return decorator


def checkpoint(func):
    """Wrap ``func`` so every call goes through ``torch.utils.checkpoint.checkpoint``.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.utils.checkpoint.checkpoint(func, ...)``. Callers may therefore
    pass checkpoint options such as ``use_reentrant`` themselves.

    NOTE(review): recent PyTorch releases warn when ``use_reentrant`` is not
    passed explicitly. Consider having callers choose it.

    Returns:
        A wrapper with ``func``'s metadata (name, docstring).
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
    return wrapper