finetune/lora/v6/fla/models/utils.py (vendored, normal file, 107 additions)
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import Cache


class RecurrentCache(Cache):
    """
    A cache used for storing hidden states produced by flash linear attention models.

    It stores the state of each layer as a tuple of tensors, each of shape `[batch_size, key_dim, value_dim]`.
    """

    def __init__(
        self,
        seen_tokens: int = 0
    ) -> None:
        self.states: List[torch.Tensor] = []
        self._seen_tokens = seen_tokens  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> torch.Tensor:
        if layer_idx < len(self):
            return self.states[layer_idx]
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        for state in self.states:
            yield state

    def __len__(self):
        return len(self.states)

    def update(
        self,
        state: Tuple[torch.Tensor],
        layer_idx: int,
        offset: Optional[int] = 1,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        """
        Updates the cache with the new `state` for the layer `layer_idx`.

        Parameters:
            state (`Tuple[torch.Tensor]`):
                The new state to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            offset (`int`, `optional`, defaults to 1):
                The number of new tokens fed in this update, used to advance the count of seen tokens.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.

        Returns:
            The updated state.
        """
        if isinstance(state, torch.Tensor):
            state = (state,)
        if len(self.states) <= layer_idx:
            self.states.append(state)
        else:
            for i, s in enumerate(state):
                self.states[layer_idx][i].copy_(s)
            # update the number of seen tokens once we reach the last layer
            if layer_idx == len(self) - 1:
                self._seen_tokens += offset

        return state
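
    # Illustrative call pattern (an assumed sketch, not part of the vendored file):
    #     cache = RecurrentCache()
    #     s = torch.zeros(2, 64, 64)                   # assumed [batch_size, key_dim, value_dim]
    #     cache.update(s, layer_idx=0)                 # first update for a layer appends (s,)
    #     cache.update(s + 1, layer_idx=0, offset=1)   # later updates copy in place and advance the token count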

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.states) <= layer_idx:
            return 0
        return self._seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. RecurrentCache does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.states)):
            # each cached entry is a tuple of tensors, so reorder every tensor along the batch dimension
            device = self.states[layer_idx][0].device
            self.states[layer_idx] = tuple(s.index_select(0, beam_idx.to(device)) for s in self.states[layer_idx])

    def to_legacy_cache(self) -> Tuple[torch.Tensor]:
        return tuple(self.states)

    @classmethod
    def from_legacy_cache(
        cls,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        seen_tokens: int = 0
    ) -> RecurrentCache:
        """Converts a cache in the legacy cache format into an equivalent `RecurrentCache`."""
        cache = cls(seen_tokens)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                cache.update(past_key_values[layer_idx], layer_idx)
        return cache
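
A minimal usage sketch of the vendored class (not part of the commit itself). It assumes the vendored `fla` package is importable as `fla.models.utils` and uses toy shapes for the per-layer recurrent state; real flash linear attention models pass their own states.

import torch

from fla.models.utils import RecurrentCache

batch_size, key_dim, value_dim, num_layers = 2, 4, 4, 2

# start from a prompt of 8 tokens whose states were computed elsewhere
cache = RecurrentCache(seen_tokens=8)

# the first update for each layer appends its state to the cache
for layer_idx in range(num_layers):
    cache.update(torch.zeros(batch_size, key_dim, value_dim), layer_idx)

# one decoding step: new states are copied in place and, on the last layer,
# the count of seen tokens advances by `offset`
for layer_idx in range(num_layers):
    cache.update(torch.randn(batch_size, key_dim, value_dim), layer_idx, offset=1)

print(len(cache))              # number of cached layers: 2
print(cache[0][0].shape)       # torch.Size([2, 4, 4])
print(cache.get_seq_length())  # tokens seen so far

# round-trip through the legacy tuple-of-states format
legacy = cache.to_legacy_cache()
restored = RecurrentCache.from_legacy_cache(legacy, seen_tokens=cache.get_seq_length())

# reorder along the batch dimension, e.g. after beam search keeps beams [1, 0]
restored.reorder_cache(torch.tensor([1, 0]))
print(restored[0][0].shape)    # still torch.Size([2, 4, 4])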