53 lines
1.7 KiB
Python
Vendored
53 lines
1.7 KiB
Python
Vendored
import torch
|
|
######state
|
|
class TimeMixState:
|
|
def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
|
|
self.shift_state = shift_state
|
|
self.wkv_state = wkv_state
|
|
|
|
|
|
class ChannelMixState:
|
|
def __init__(self, shift_state: torch.Tensor):
|
|
self.shift_state = shift_state
|
|
|
|
|
|
class BlockState:
|
|
def __init__(self, time_mix_state: TimeMixState,
|
|
channel_mix_state: ChannelMixState):
|
|
self.time_mix_state = time_mix_state
|
|
self.channel_mix_state = channel_mix_state
|
|
|
|
class BlockStateList:
|
|
|
|
def __init__(self, shift_states, wkv_states):
|
|
self.wkv_states = wkv_states
|
|
self.shift_states = shift_states
|
|
|
|
@staticmethod
|
|
def create(N, B, C, H, device, dtype):
|
|
result = BlockStateList.empty(N, B, C, H, device, dtype)
|
|
result.wkv_states[:] = 0
|
|
result.wkv_states[:] = 0
|
|
result.shift_states[:] = 0
|
|
return result
|
|
|
|
@staticmethod
|
|
def empty(N, B, C, H, device, dtype):
|
|
wkv_states = torch.empty((N, B, H, C//H, C//H),
|
|
device=device,
|
|
dtype=torch.bfloat16)
|
|
shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
|
|
return BlockStateList(shift_states, wkv_states)
|
|
|
|
def __getitem__(self, layer: int):
|
|
return BlockState(
|
|
TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
|
|
ChannelMixState(self.shift_states[layer, 1]))
|
|
|
|
def __setitem__(self, layer: int, state: BlockState):
|
|
self.shift_states[layer, 0] = state.time_mix_state.shift_state
|
|
self.wkv_states[layer] = state.time_mix_state.wkv_state
|
|
self.shift_states[layer, 1] = state.channel_mix_state.shift_state
|
|
|
|
|