finetune/lora/v6/src/binidx.py (vendored, 4 changed lines)
@@ -270,8 +270,10 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
             np_array = np.append(np_array, np_array0)
         return np_array

-    def only(self, idx):
+    def only(self, idx, length=None):
         ptr, size = self._index[idx]
+        if length < size:
+            size = length
         np_array = np.frombuffer(
             self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
         )
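A minimal usage sketch (not part of the diff) contrasting the three MMapIndexedDataset access paths used by the finetune code; data, idx, i and req_len are assumed from the surrounding dataset code, and pad() is assumed (per its name) to pad the sample out to the requested length. Note that only() needs an explicit length here, since comparing the default length=None against size would raise a TypeError:

    dix_get  = data.get(idx=0, offset=i, length=req_len)   # req_len tokens starting at offset i of document 0
    dix_pad  = data.pad(idx=idx, length=req_len)            # one whole sample, padded out to req_len
    dix_only = data.only(idx=idx, length=req_len)           # one whole sample, truncated down to req_len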
finetune/lora/v6/src/dataset.py (vendored, 8 changed lines)
@@ -179,8 +179,12 @@ class MyDataset(Dataset):
 
         if args.data_type == "binidx":
             if args.my_pile_version == 1:
-                dix = data.get(idx=0, offset=i, length=req_len).astype(int)
                 # dix = data.pad(idx=idx, length=req_len).astype(int)
+                if args.dataload == "pad":
+                    dix = data.pad(idx=idx, length=req_len).astype(int)
+                elif args.dataload == "only":
+                    dix = data.only(idx=idx, length=req_len).astype(int)
+                else:
+                    dix = data.get(idx=0, offset=i, length=req_len).astype(int)
             else:
                 # self.data : cutoff, chunk_count, data
                 for j in range(len(data)):
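A minimal sketch (assumed, not shown in this diff) of the argument that drives the branch above; the flag name follows args.dataload, while the default value and the choices list are assumptions:

    import argparse

    parser = argparse.ArgumentParser()
    # "get": slide a req_len window over the concatenated corpus (previous behaviour)
    # "pad": one whole sample per index, padded to req_len
    # "only": one whole sample per index, truncated to req_len
    parser.add_argument("--dataload", default="get", type=str, choices=["get", "pad", "only"])
    args = parser.parse_args(["--dataload", "only"])
    assert args.dataload == "only"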
finetune/lora/v6/src/infctx_module.py (vendored, new file, 52 lines)
@@ -0,0 +1,52 @@
import torch


######state
class TimeMixState:
    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state


class ChannelMixState:
    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state


class BlockState:
    def __init__(self, time_mix_state: TimeMixState,
                 channel_mix_state: ChannelMixState):
        self.time_mix_state = time_mix_state
        self.channel_mix_state = channel_mix_state


class BlockStateList:

    def __init__(self, shift_states, wkv_states):
        self.wkv_states = wkv_states
        self.shift_states = shift_states

    @staticmethod
    def create(N, B, C, H, device, dtype):
        result = BlockStateList.empty(N, B, C, H, device, dtype)
        result.wkv_states[:] = 0
        result.wkv_states[:] = 0
        result.shift_states[:] = 0
        return result

    @staticmethod
    def empty(N, B, C, H, device, dtype):
        wkv_states = torch.empty((N, B, H, C//H, C//H),
                                 device=device,
                                 dtype=torch.bfloat16)
        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
        return BlockStateList(shift_states, wkv_states)

    def __getitem__(self, layer: int):
        return BlockState(
            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
            ChannelMixState(self.shift_states[layer, 1]))

    def __setitem__(self, layer: int, state: BlockState):
        self.shift_states[layer, 0] = state.time_mix_state.shift_state
        self.wkv_states[layer] = state.time_mix_state.wkv_state
        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
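A minimal sketch (not part of the diff) of how these classes would typically be threaded through an infctx-style forward pass; blocks, n_layer, n_embd, head_count and batch_size are assumed names, and each block is assumed to take and return its per-layer BlockState:

    states = BlockStateList.create(N=n_layer, B=batch_size, C=n_embd, H=head_count,
                                   device=x.device, dtype=x.dtype)
    for i, block in enumerate(blocks):
        x, new_state = block(x, states[i])   # __getitem__ packs shift/wkv state for layer i
        states[i] = new_state                # __setitem__ writes the updated state back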
finetune/lora/v6/src/model.py (vendored, 1508 changed lines; file diff suppressed because it is too large)
finetune/lora/v6/src/trainer.py (vendored, 43 changed lines)
@@ -4,6 +4,8 @@ from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
 from .model import LORA_CONFIG
+import re
+import numpy as np


 def my_save(args, trainer, dd, ff):
@@ -21,10 +23,7 @@ def my_save(args, trainer, dd, ff):
             f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True
         )
     else:
-        if "deepspeed_stage_3" in args.strategy:
-            trainer.save_checkpoint(ff, weights_only=True)
-        else:
-            torch.save(dd, ff)
+        torch.save(dd, ff)


 class train_callback(pl.Callback):
@@ -181,6 +180,30 @@ class train_callback(pl.Callback):
                         to_save_dict,
                         f"{args.proj_dir}/rwkv-final.pth",
                     )
+
+        if args.LISA and (batch_idx + 1) % args.lisa_k == 0:
+            pl_module.requires_grad_(False)
+            select_layers = np.random.choice(
+                range(args.n_layer), args.lisa_r, replace=False
+            )
+
+            for name, module in pl_module.named_modules():
+                for pname, param in module.named_parameters():
+                    if (
+                        "emb" in pname
+                        or "head" in pname
+                        or ".ln" in pname
+                        or "time" in pname
+                    ):
+                        param.requires_grad = True
+                    elif "ln_out" in pname:
+                        param.requires_grad = True
+                    match = re.search(r"\d+", pname)
+                    if match:
+                        number = int(match.group())
+                        if number in select_layers:
+                            param.requires_grad = True
+                break
         # if args.batch_save==batch_idx :
         #     to_save_dict = pl_module.state_dict()
         #     for name, state in to_save_dict.items():
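A self-contained sketch (not from the diff) of the LISA selection logic added above: every lisa_k steps the model is frozen, then embeddings, head, LayerNorm and time-mix parameters are re-enabled, along with every parameter whose name carries one of lisa_r randomly chosen layer indices; the parameter names below are invented examples:

    import re
    import numpy as np

    n_layer, lisa_r = 24, 2
    select_layers = np.random.choice(range(n_layer), lisa_r, replace=False)

    for pname in ["emb.weight", "blocks.3.att.key.weight", "blocks.17.ffn.value.weight", "head.weight"]:
        trainable = any(tag in pname for tag in ("emb", "head", ".ln", "time", "ln_out"))
        match = re.search(r"\d+", pname)
        if match and int(match.group()) in select_layers:
            trainable = True
        print(pname, "->", "trainable" if trainable else "frozen")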
@@ -229,12 +252,22 @@ class train_callback(pl.Callback):
                     if "img" in name:
                         to_save_dict[name] = state
 
+            if args.state_tune or args.train_type == "state":
+                lora_dict = {}
+                for name, state in to_save_dict.items():
+                    if "state" in name:
+                        lora_dict[name] = state
+                to_save_dict = lora_dict
+
             if args.lora:
                 enable_time_finetune = "time" in LORA_CONFIG["parts"]
                 enable_ln_finetune = "ln" in LORA_CONFIG["parts"]
                 lora_dict = {}
                 for name, state in to_save_dict.items():
-                    if "img" in name:
+                    if len(args.load_model) == 0:
+                        if "emb" in name or "head" in name or "ln" in name:
+                            lora_dict[name] = state
+                    if args.emb and "emb" in name:
+                        lora_dict[name] = state
                     if (
                         ".lora_" in name
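A minimal sketch (hypothetical parameter names, not from the diff) of the filtering pattern used for state tuning above: only parameters whose names contain "state" are kept, so the saved checkpoint holds just the tuned state and stays small:

    import torch

    full_sd = {
        "blocks.0.att.time_state": torch.zeros(64, 64),   # hypothetical tuned state parameter
        "blocks.0.att.key.weight": torch.zeros(64, 64),   # frozen base weight, dropped from the save
    }
    state_only = {k: v for k, v in full_sd.items() if "state" in k}
    torch.save(state_only, "rwkv-state-only.pth")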