lora finetune (need to be refactored)
This commit is contained in:
parent
c54d10795f
commit
987854fe49
2
.gitattributes
vendored
2
.gitattributes
vendored
@ -3,4 +3,6 @@ backend-python/wkv_cuda_utils/** linguist-vendored
|
||||
backend-python/get-pip.py linguist-vendored
|
||||
backend-python/convert_model.py linguist-vendored
|
||||
build/** linguist-vendored
|
||||
finetune/lora/** linguist-vendored
|
||||
finetune/json2binidx_tool/** linguist-vendored
|
||||
frontend/wailsjs/** linguist-generated
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -20,4 +20,6 @@ __pycache__
|
||||
*.old
|
||||
.DS_Store
|
||||
*.log.*
|
||||
*.log
|
||||
*.log
|
||||
train_log.txt
|
||||
finetune/json2binidx_tool/data
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/minio/selfupdate"
|
||||
wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
@ -41,6 +42,27 @@ func (a *App) OnStartup(ctx context.Context) {
|
||||
}
|
||||
|
||||
a.downloadLoop()
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err == nil {
|
||||
watcher.Add("./lora-models")
|
||||
watcher.Add("./models")
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
wruntime.EventsEmit(ctx, "fsnotify", event.Name)
|
||||
case _, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) UpdateApp(url string) (broken bool, err error) {
|
||||
|
@ -31,6 +31,38 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, out
|
||||
return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy)
|
||||
}
|
||||
|
||||
func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tokenizerType := "HFTokenizer"
|
||||
if strings.Contains(vocab, "rwkv_vocab_v20230424") {
|
||||
tokenizerType = "RWKVTokenizer"
|
||||
}
|
||||
return Cmd(python, "./finetune/json2binidx_tool/tools/preprocess_data.py", "--input", input, "--output-prefix", outputPrefix, "--vocab", vocab,
|
||||
"--tokenizer-type", tokenizerType, "--dataset-impl", "mmap", "--append-eod")
|
||||
}
|
||||
|
||||
func (a *App) MergeLora(python string, useGpu bool, loraAlpha int, baseModel string, loraPath string, outputPath string) (string, error) {
|
||||
var err error
|
||||
if python == "" {
|
||||
python, err = GetPython()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
args := []string{python, "./finetune/lora/merge_lora.py"}
|
||||
if useGpu {
|
||||
args = append(args, "--use-gpu")
|
||||
}
|
||||
args = append(args, strconv.Itoa(loraAlpha), baseModel, loraPath, outputPath)
|
||||
return Cmd(args...)
|
||||
}
|
||||
|
||||
func (a *App) DepCheck(python string) error {
|
||||
var err error
|
||||
if python == "" {
|
||||
|
202
backend-golang/wsl.go
Normal file
202
backend-golang/wsl.go
Normal file
@ -0,0 +1,202 @@
|
||||
package backend_golang
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
su "github.com/nyaosorg/go-windows-su"
|
||||
wsl "github.com/ubuntu/gowsl"
|
||||
wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
)
|
||||
|
||||
var distro *wsl.Distro
|
||||
var stdin io.WriteCloser
|
||||
var cmd *exec.Cmd
|
||||
|
||||
func isWslRunning() (bool, error) {
|
||||
if distro == nil {
|
||||
return false, nil
|
||||
}
|
||||
state, err := distro.State()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if state != wsl.Running {
|
||||
distro = nil
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (a *App) WslStart() error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
running, err := isWslRunning()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if running {
|
||||
return nil
|
||||
}
|
||||
distros, err := wsl.RegisteredDistros(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, d := range distros {
|
||||
if strings.Contains(d.Name(), "Ubuntu") {
|
||||
distro = &d
|
||||
break
|
||||
}
|
||||
}
|
||||
if distro == nil {
|
||||
return errors.New("ubuntu not found")
|
||||
}
|
||||
|
||||
cmd = exec.Command("wsl", "-d", distro.Name(), "-u", "root")
|
||||
|
||||
stdin, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
cmd.Stderr = cmd.Stdout
|
||||
if err != nil {
|
||||
// stdin.Close()
|
||||
stdin = nil
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
reader := bufio.NewReader(stdout)
|
||||
for {
|
||||
if stdin == nil {
|
||||
break
|
||||
}
|
||||
line, _, err := reader.ReadLine()
|
||||
if err != nil {
|
||||
wruntime.EventsEmit(a.ctx, "wslerr", err.Error())
|
||||
break
|
||||
}
|
||||
wruntime.EventsEmit(a.ctx, "wsl", string(line))
|
||||
}
|
||||
// stdout.Close()
|
||||
}()
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) WslCommand(command string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
running, err := isWslRunning()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !running {
|
||||
return errors.New("wsl not running")
|
||||
}
|
||||
_, err = stdin.Write([]byte(command + "\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) WslStop() error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
running, err := isWslRunning()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !running {
|
||||
return errors.New("wsl not running")
|
||||
}
|
||||
err = cmd.Process.Kill()
|
||||
cmd = nil
|
||||
// stdin.Close()
|
||||
stdin = nil
|
||||
distro = nil
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) WslIsEnabled() error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exDir := filepath.Dir(ex)
|
||||
|
||||
data, err := os.ReadFile(exDir + "/wsl.state")
|
||||
if err == nil {
|
||||
if strings.Contains(string(data), "Enabled") {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
cmd := `-Command (Get-WindowsOptionalFeature -Online -FeatureName Microsoft-Windows-Subsystem-Linux).State | Out-File -Encoding utf8 -FilePath ` + exDir + "/wsl.state"
|
||||
_, err = su.ShellExecute(su.RUNAS, "powershell", cmd, exDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
data, err = os.ReadFile(exDir + "/wsl.state")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.Contains(string(data), "Enabled") {
|
||||
return nil
|
||||
} else {
|
||||
return errors.New("wsl is not enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) WslEnable(forceMode bool) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
cmd := `/online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux`
|
||||
_, err := su.ShellExecute(su.RUNAS, "dism", cmd, `C:\`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if forceMode {
|
||||
os.WriteFile("./wsl.state", []byte("Enabled"), 0644)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) WslInstallUbuntu() error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return errors.New("wsl not supported")
|
||||
}
|
||||
|
||||
exec.Command("start", "ms-windows-store://pdp/?ProductId=9PN20MSR04DW").Start()
|
||||
return nil
|
||||
}
|
@ -1,8 +1,13 @@
|
||||
import lm_dataformat
|
||||
import ftfy
|
||||
import tqdm
|
||||
import tiktoken
|
||||
import GPUtil
|
||||
|
||||
import torch
|
||||
import rwkv
|
||||
import numpy
|
||||
import tokenizers
|
||||
import fastapi
|
||||
import uvicorn
|
||||
import sse_starlette
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -95,27 +95,64 @@ async def eval_rwkv(
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
completion_lock.acquire()
|
||||
if await request.is_disconnected():
|
||||
completion_lock.release()
|
||||
with completion_lock:
|
||||
if await request.is_disconnected():
|
||||
requests_num = requests_num - 1
|
||||
print(f"{request.client} Stop Waiting (Lock)")
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||
set_rwkv_config(model, body)
|
||||
|
||||
response, prompt_tokens, completion_tokens = "", 0, 0
|
||||
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
||||
prompt,
|
||||
stop=stop,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
if stream:
|
||||
yield json.dumps(
|
||||
{
|
||||
"object": "chat.completion.chunk"
|
||||
if chat_mode
|
||||
else "text_completion",
|
||||
"response": response,
|
||||
"model": model.name,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {"content": delta},
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
}
|
||||
if chat_mode
|
||||
else {
|
||||
"text": delta,
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
# torch_gc()
|
||||
requests_num = requests_num - 1
|
||||
print(f"{request.client} Stop Waiting (Lock)")
|
||||
if await request.is_disconnected():
|
||||
print(f"{request.client} Stop Waiting")
|
||||
quick_log(
|
||||
request,
|
||||
body,
|
||||
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
|
||||
body,
|
||||
response + "\nFinished. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||
set_rwkv_config(model, body)
|
||||
|
||||
response, prompt_tokens, completion_tokens = "", 0, 0
|
||||
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
||||
prompt,
|
||||
stop=stop,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
if stream:
|
||||
yield json.dumps(
|
||||
{
|
||||
@ -126,86 +163,47 @@ async def eval_rwkv(
|
||||
"model": model.name,
|
||||
"choices": [
|
||||
{
|
||||
"delta": {"content": delta},
|
||||
"delta": {},
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
if chat_mode
|
||||
else {
|
||||
"text": delta,
|
||||
"text": "",
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
# torch_gc()
|
||||
requests_num = requests_num - 1
|
||||
completion_lock.release()
|
||||
if await request.is_disconnected():
|
||||
print(f"{request.client} Stop Waiting")
|
||||
quick_log(
|
||||
request,
|
||||
body,
|
||||
response + "\nStop Waiting. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
quick_log(
|
||||
request,
|
||||
body,
|
||||
response + "\nFinished. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
if stream:
|
||||
yield json.dumps(
|
||||
{
|
||||
"object": "chat.completion.chunk"
|
||||
if chat_mode
|
||||
else "text_completion",
|
||||
yield "[DONE]"
|
||||
else:
|
||||
yield {
|
||||
"object": "chat.completion" if chat_mode else "text_completion",
|
||||
"response": response,
|
||||
"model": model.name,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"delta": {},
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
if chat_mode
|
||||
else {
|
||||
"text": "",
|
||||
"text": response,
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
yield "[DONE]"
|
||||
else:
|
||||
yield {
|
||||
"object": "chat.completion" if chat_mode else "text_completion",
|
||||
"response": response,
|
||||
"model": model.name,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
if chat_mode
|
||||
else {
|
||||
"text": response,
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@ -372,81 +370,88 @@ async def embeddings(body: EmbeddingsBody, request: Request):
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
completion_lock.acquire()
|
||||
if await request.is_disconnected():
|
||||
completion_lock.release()
|
||||
requests_num = requests_num - 1
|
||||
print(f"{request.client} Stop Waiting (Lock)")
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
with completion_lock:
|
||||
if await request.is_disconnected():
|
||||
requests_num = requests_num - 1
|
||||
print(f"{request.client} Stop Waiting (Lock)")
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
|
||||
base64_format = False
|
||||
if body.encoding_format == "base64":
|
||||
base64_format = True
|
||||
base64_format = False
|
||||
if body.encoding_format == "base64":
|
||||
base64_format = True
|
||||
|
||||
embeddings = []
|
||||
prompt_tokens = 0
|
||||
if type(body.input) == list:
|
||||
if type(body.input[0]) == list:
|
||||
encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
|
||||
for i in range(len(body.input)):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
input = encoding.decode(body.input[i])
|
||||
embedding, token_len = model.get_embedding(input, body.fast_mode)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
for i in range(len(body.input)):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
embedding, token_len = model.get_embedding(
|
||||
body.input[i], body.fast_mode
|
||||
embeddings = []
|
||||
prompt_tokens = 0
|
||||
if type(body.input) == list:
|
||||
if type(body.input[0]) == list:
|
||||
encoding = tiktoken.model.encoding_for_model(
|
||||
"text-embedding-ada-002"
|
||||
)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
embedding, prompt_tokens = model.get_embedding(body.input, body.fast_mode)
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
for i in range(len(body.input)):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
input = encoding.decode(body.input[i])
|
||||
embedding, token_len = model.get_embedding(
|
||||
input, body.fast_mode
|
||||
)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
for i in range(len(body.input)):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
embedding, token_len = model.get_embedding(
|
||||
body.input[i], body.fast_mode
|
||||
)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
embedding, prompt_tokens = model.get_embedding(
|
||||
body.input, body.fast_mode
|
||||
)
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
|
||||
requests_num = requests_num - 1
|
||||
completion_lock.release()
|
||||
if await request.is_disconnected():
|
||||
print(f"{request.client} Stop Waiting")
|
||||
requests_num = requests_num - 1
|
||||
if await request.is_disconnected():
|
||||
print(f"{request.client} Stop Waiting")
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Stop Waiting. RequestsNum: " + str(requests_num),
|
||||
"Finished. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
return
|
||||
quick_log(
|
||||
request,
|
||||
None,
|
||||
"Finished. RequestsNum: " + str(requests_num),
|
||||
)
|
||||
|
||||
ret_data = [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": i,
|
||||
"embedding": embedding,
|
||||
ret_data = [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": i,
|
||||
"embedding": embedding,
|
||||
}
|
||||
for i, embedding in enumerate(embeddings)
|
||||
]
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": ret_data,
|
||||
"model": model.name,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"total_tokens": prompt_tokens,
|
||||
},
|
||||
}
|
||||
for i, embedding in enumerate(embeddings)
|
||||
]
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
"data": ret_data,
|
||||
"model": model.name,
|
||||
"usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
|
||||
}
|
||||
|
7
finetune/data/sample.jsonl
Normal file
7
finetune/data/sample.jsonl
Normal file
@ -0,0 +1,7 @@
|
||||
{"text": "1:This is the first document."}
|
||||
{"text": "2:Hello\nWorld"}
|
||||
{"text": "3:1+1=2\n1+2=3\n2+2=4"}
|
||||
{"text": "4:You will be training the GPT version because it's paralleziable and faster to train."}
|
||||
{"text": "5:Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb)"}
|
||||
{"text": "6:You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens."}
|
||||
{"text": "7:Consider RWKV 14B. The state has 200 vectors, that is, 5 vectors for each block: fp16 (xx), fp32 (aa), fp32 (bb), fp32 (pp), fp16 (xx)."}
|
41
finetune/get_layer_and_embd.py
Normal file
41
finetune/get_layer_and_embd.py
Normal file
@ -0,0 +1,41 @@
|
||||
import torch
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
import gc
|
||||
|
||||
|
||||
def file_cleaner(file):
|
||||
last_pos = 0
|
||||
|
||||
def cleaner():
|
||||
nonlocal last_pos
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
pos = file.tell()
|
||||
if pos > last_pos:
|
||||
os.posix_fadvise(
|
||||
file.fileno(), last_pos, pos - last_pos, os.POSIX_FADV_DONTNEED
|
||||
)
|
||||
last_pos = pos
|
||||
|
||||
return cleaner
|
||||
|
||||
|
||||
model_file = open(sys.argv[1], "rb")
|
||||
cleaner = file_cleaner(model_file)
|
||||
cleaner_thread = threading.Thread(target=cleaner, daemon=True)
|
||||
cleaner_thread.start()
|
||||
|
||||
w = torch.load(model_file, map_location="cpu")
|
||||
gc.collect()
|
||||
|
||||
n_embd = w["emb.weight"].shape[1]
|
||||
n_layer = 0
|
||||
keys = list(w.keys())
|
||||
for x in keys:
|
||||
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
|
||||
n_layer = max(n_layer, layer_id + 1)
|
||||
|
||||
print(f"--n_layer {n_layer} --n_embd {n_embd}", end="")
|
46
finetune/install-wsl-dep-and-train.sh
Normal file
46
finetune/install-wsl-dep-and-train.sh
Normal file
@ -0,0 +1,46 @@
|
||||
if [[ ${cnMirror} == 1 ]]; then
|
||||
export PIP_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
if grep -q "mirrors.aliyun.com" /etc/apt/sources.list; then
|
||||
echo "apt cnMirror already set"
|
||||
else
|
||||
sudo sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list
|
||||
sudo apt update
|
||||
fi
|
||||
fi
|
||||
|
||||
if dpkg -s "python3-pip" >/dev/null 2>&1; then
|
||||
echo "pip installed"
|
||||
else
|
||||
sudo apt install python3-pip
|
||||
fi
|
||||
|
||||
if dpkg -s "ninja-build" >/dev/null 2>&1; then
|
||||
echo "ninja installed"
|
||||
else
|
||||
sudo apt install ninja-build
|
||||
fi
|
||||
|
||||
if dpkg -s "cuda" >/dev/null 2>&1; then
|
||||
echo "cuda installed"
|
||||
else
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
|
||||
sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
|
||||
wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
|
||||
sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb
|
||||
sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install cuda
|
||||
fi
|
||||
|
||||
if python3 -c "import pkg_resources; pkg_resources.require(open('./finetune/requirements.txt',mode='r'))" &>/dev/null; then
|
||||
echo "requirements satisfied"
|
||||
else
|
||||
python3 -m pip install -r ./finetune/requirements.txt
|
||||
fi
|
||||
|
||||
echo "loading $loadModel"
|
||||
modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel)
|
||||
echo $modelInfo
|
||||
|
||||
python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
|
||||
--lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu
|
597
finetune/json2binidx_tool/tools/indexed_dataset.py
vendored
Normal file
597
finetune/json2binidx_tool/tools/indexed_dataset.py
vendored
Normal file
@ -0,0 +1,597 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# copied from fairseq/fairseq/data/indexed_dataset.py
|
||||
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
|
||||
# other slight modifications to remove fairseq dependencies
|
||||
# Added document index to index file and made it accessible.
|
||||
# An empty sentence no longer separates documents.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
|
||||
def __best_fitting_dtype(vocab_size=None):
|
||||
if vocab_size is not None and vocab_size < 65500:
|
||||
return np.uint16
|
||||
else:
|
||||
return np.int32
|
||||
|
||||
|
||||
def infer_dataset_impl(path):
|
||||
if IndexedDataset.exists(path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
if magic == IndexedDataset._HDR_MAGIC:
|
||||
return "cached"
|
||||
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
|
||||
return "mmap"
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def make_builder(out_file, impl, vocab_size=None):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDatasetBuilder(
|
||||
out_file, dtype=__best_fitting_dtype(vocab_size)
|
||||
)
|
||||
else:
|
||||
return IndexedDatasetBuilder(out_file)
|
||||
|
||||
|
||||
def make_dataset(path, impl, skip_warmup=False):
|
||||
if not IndexedDataset.exists(path):
|
||||
print(f"Dataset does not exist: {path}")
|
||||
print(
|
||||
"Path should be a basename that both .idx and .bin can be appended to get full filenames."
|
||||
)
|
||||
return None
|
||||
if impl == "infer":
|
||||
impl = infer_dataset_impl(path)
|
||||
if impl == "lazy" and IndexedDataset.exists(path):
|
||||
return IndexedDataset(path)
|
||||
elif impl == "cached" and IndexedDataset.exists(path):
|
||||
return IndexedCachedDataset(path)
|
||||
elif impl == "mmap" and MMapIndexedDataset.exists(path):
|
||||
return MMapIndexedDataset(path, skip_warmup)
|
||||
print(f"Unknown dataset implementation: {impl}")
|
||||
return None
|
||||
|
||||
|
||||
def dataset_exists(path, impl):
|
||||
if impl == "mmap":
|
||||
return MMapIndexedDataset.exists(path)
|
||||
else:
|
||||
return IndexedDataset.exists(path)
|
||||
|
||||
|
||||
def read_longs(f, n):
|
||||
a = np.empty(n, dtype=np.int64)
|
||||
f.readinto(a)
|
||||
return a
|
||||
|
||||
|
||||
def write_longs(f, a):
|
||||
f.write(np.array(a, dtype=np.int64))
|
||||
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: np.float32,
|
||||
7: np.float64,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
|
||||
def create_doc_idx(sizes):
|
||||
doc_idx = [0]
|
||||
for i, s in enumerate(sizes):
|
||||
if s == 0:
|
||||
doc_idx.append(i + 1)
|
||||
return doc_idx
|
||||
|
||||
|
||||
class IndexedDataset(torch.utils.data.Dataset):
|
||||
"""Loader for IndexedDataset"""
|
||||
|
||||
_HDR_MAGIC = b"TNTIDX\x00\x00"
|
||||
|
||||
def __init__(self, path):
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.data_file = None
|
||||
self.read_index(path)
|
||||
|
||||
def read_index(self, path):
|
||||
with open(index_file_path(path), "rb") as f:
|
||||
magic = f.read(8)
|
||||
assert magic == self._HDR_MAGIC, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
version = f.read(8)
|
||||
assert struct.unpack("<Q", version) == (1,)
|
||||
code, self.element_size = struct.unpack("<QQ", f.read(16))
|
||||
self.dtype = dtypes[code]
|
||||
self._len, self.s = struct.unpack("<QQ", f.read(16))
|
||||
self.doc_count = struct.unpack("<Q", f.read(8))
|
||||
self.dim_offsets = read_longs(f, self._len + 1)
|
||||
self.data_offsets = read_longs(f, self._len + 1)
|
||||
self.sizes = read_longs(f, self.s)
|
||||
self.doc_idx = read_longs(f, self.doc_count)
|
||||
|
||||
def read_data(self, path):
|
||||
self.data_file = open(data_file_path(path), "rb", buffering=0)
|
||||
|
||||
def check_index(self, i):
|
||||
if i < 0 or i >= self._len:
|
||||
raise IndexError("index out of range")
|
||||
|
||||
def __del__(self):
|
||||
if self.data_file:
|
||||
self.data_file.close()
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]]
|
||||
size = sum(sizes)
|
||||
a = np.empty(size, dtype=self.dtype)
|
||||
self.data_file.seek(self.data_offsets[start] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
offsets = list(accumulate(sizes))
|
||||
sents = np.split(a, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.sizes[index]
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False # avoid prefetching to save memory
|
||||
|
||||
|
||||
class IndexedCachedDataset(IndexedDataset):
|
||||
def __init__(self, path):
|
||||
super().__init__(path)
|
||||
self.cache = None
|
||||
self.cache_index = {}
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return True
|
||||
|
||||
def prefetch(self, indices):
|
||||
if all(i in self.cache_index for i in indices):
|
||||
return
|
||||
if not self.data_file:
|
||||
self.read_data(self.path)
|
||||
indices = sorted(set(indices))
|
||||
total_size = 0
|
||||
for i in indices:
|
||||
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
self.cache = np.empty(total_size, dtype=self.dtype)
|
||||
ptx = 0
|
||||
self.cache_index.clear()
|
||||
for i in indices:
|
||||
self.cache_index[i] = ptx
|
||||
size = self.data_offsets[i + 1] - self.data_offsets[i]
|
||||
a = self.cache[ptx : ptx + size]
|
||||
self.data_file.seek(self.data_offsets[i] * self.element_size)
|
||||
self.data_file.readinto(a)
|
||||
ptx += size
|
||||
if self.data_file:
|
||||
# close and delete data file after prefetch so we can pickle
|
||||
self.data_file.close()
|
||||
self.data_file = None
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
i = idx
|
||||
self.check_index(i)
|
||||
tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]]
|
||||
a = np.empty(tensor_size, dtype=self.dtype)
|
||||
ptx = self.cache_index[i]
|
||||
np.copyto(a, self.cache[ptx : ptx + a.size])
|
||||
return a
|
||||
elif isinstance(idx, slice):
|
||||
# Hack just to make this work, can optimizer later if necessary
|
||||
sents = []
|
||||
for i in range(*idx.indices(len(self))):
|
||||
sents.append(self[i])
|
||||
return sents
|
||||
|
||||
|
||||
class IndexedDatasetBuilder(object):
|
||||
element_sizes = {
|
||||
np.uint8: 1,
|
||||
np.int8: 1,
|
||||
np.int16: 2,
|
||||
np.int32: 4,
|
||||
np.int64: 8,
|
||||
np.float32: 4,
|
||||
np.float64: 8,
|
||||
}
|
||||
|
||||
def __init__(self, out_file, dtype=np.int32):
|
||||
self.out_file = open(out_file, "wb")
|
||||
self.dtype = dtype
|
||||
self.data_offsets = [0]
|
||||
self.dim_offsets = [0]
|
||||
self.sizes = []
|
||||
self.element_size = self.element_sizes[self.dtype]
|
||||
self.doc_idx = [0]
|
||||
|
||||
def add_item(self, np_array):
|
||||
assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
|
||||
bytes = self.out_file.write(np_array)
|
||||
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
|
||||
for s in np_array.shape:
|
||||
self.sizes.append(s)
|
||||
self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape))
|
||||
|
||||
def end_document(self):
|
||||
self.doc_idx.append(len(self.sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
index = IndexedDataset(another_file)
|
||||
assert index.dtype == self.dtype
|
||||
|
||||
begin = self.data_offsets[-1]
|
||||
for offset in index.data_offsets[1:]:
|
||||
self.data_offsets.append(begin + offset)
|
||||
self.sizes.extend(index.sizes)
|
||||
begin = self.dim_offsets[-1]
|
||||
for dim_offset in index.dim_offsets[1:]:
|
||||
self.dim_offsets.append(begin + dim_offset)
|
||||
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
while True:
|
||||
data = f.read(1024)
|
||||
if data:
|
||||
self.out_file.write(data)
|
||||
else:
|
||||
break
|
||||
|
||||
def finalize(self, index_file):
|
||||
self.out_file.close()
|
||||
index = open(index_file, "wb")
|
||||
index.write(b"TNTIDX\x00\x00")
|
||||
index.write(struct.pack("<Q", 1))
|
||||
index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
|
||||
index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
|
||||
index.write(struct.pack("<Q", len(self.doc_idx)))
|
||||
write_longs(index, self.dim_offsets)
|
||||
write_longs(index, self.data_offsets)
|
||||
write_longs(index, self.sizes)
|
||||
write_longs(index, self.doc_idx)
|
||||
index.close()
|
||||
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
with open(path, "rb") as stream:
|
||||
while stream.read(100 * 1024 * 1024):
|
||||
pass
|
||||
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# Write Magic string so we can check the file format then opening it again.
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
# Write version number
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
# Little endian unsigned 8 Bit integer
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
pointers = np.zeros(len(sizes), dtype=np.int64)
|
||||
sizes = np.array(sizes, dtype=np.int64)
|
||||
|
||||
np.cumsum(sizes[:-1], out=pointers[1:])
|
||||
pointers = pointers * dtype().itemsize
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
# Little endian unsigned 64 Bit integer
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
# Little endian unsigned 8 Bit integer
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
print(" reading pointers...")
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
print(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError("Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
||||
|
||||
|
||||
class MMapIndexedDatasetBuilder(object):
|
||||
def __init__(self, out_file, dtype=np.int64):
|
||||
self._data_file = open(out_file, "wb")
|
||||
self._dtype = dtype
|
||||
self._sizes = []
|
||||
self._doc_idx = [0]
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def add_item(self, np_array):
|
||||
assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype
|
||||
self._data_file.write(np_array.tobytes(order="C"))
|
||||
self._sizes.append(np_array.size)
|
||||
|
||||
def end_document(self):
|
||||
self._doc_idx.append(len(self._sizes))
|
||||
|
||||
def merge_file_(self, another_file):
|
||||
# Concatenate index
|
||||
index = MMapIndexedDataset.Index(index_file_path(another_file))
|
||||
assert index.dtype == self._dtype
|
||||
|
||||
for size in index.sizes:
|
||||
self._sizes.append(size)
|
||||
|
||||
# Concatenate data
|
||||
with open(data_file_path(another_file), "rb") as f:
|
||||
shutil.copyfileobj(f, self._data_file)
|
||||
|
||||
def finalize(self, index_file):
|
||||
self._data_file.close()
|
||||
|
||||
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
|
||||
index.write(self._sizes, self._doc_idx)
|
243
finetune/json2binidx_tool/tools/preprocess_data.py
vendored
Normal file
243
finetune/json2binidx_tool/tools/preprocess_data.py
vendored
Normal file
@ -0,0 +1,243 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processing data for pretraining."""
|
||||
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
|
||||
import lm_dataformat as lmd
|
||||
import numpy as np
|
||||
|
||||
sys.path.append(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
|
||||
)
|
||||
import time
|
||||
import tqdm
|
||||
import ftfy
|
||||
|
||||
from tokenizer import build_tokenizer
|
||||
import indexed_dataset
|
||||
from threading import Semaphore
|
||||
|
||||
|
||||
class Encoder(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
def initializer(self):
|
||||
# Use Encoder class as a container for global data
|
||||
Encoder.tokenizer = build_tokenizer(self.args)
|
||||
|
||||
def encode(self, text):
|
||||
if self.args.ftfy:
|
||||
text = ftfy.fix_text(text)
|
||||
ids = {}
|
||||
for key in self.args.jsonl_keys:
|
||||
doc_ids = []
|
||||
text_ids = Encoder.tokenizer.tokenize(text)
|
||||
if len(text_ids) > 0:
|
||||
doc_ids.append(text_ids)
|
||||
if self.args.append_eod:
|
||||
doc_ids[-1].append(Encoder.tokenizer.eod)
|
||||
ids[key] = doc_ids
|
||||
return ids, len(text)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_argument_group(title="input data")
|
||||
group.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated "
|
||||
"list",
|
||||
)
|
||||
group.add_argument(
|
||||
"--jsonl-keys",
|
||||
nargs="+",
|
||||
default=["text"],
|
||||
help="space separate listed of keys to extract from jsonl. Defa",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-docs",
|
||||
default=None,
|
||||
help="Optional: Number of documents in the input data (if known) for an accurate progress bar.",
|
||||
type=int,
|
||||
)
|
||||
group = parser.add_argument_group(title="tokenizer")
|
||||
group.add_argument(
|
||||
"--tokenizer-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"HFGPT2Tokenizer",
|
||||
"HFTokenizer",
|
||||
"GPT2BPETokenizer",
|
||||
"CharLevelTokenizer",
|
||||
"TiktokenTokenizer",
|
||||
"RWKVTokenizer",
|
||||
],
|
||||
help="What type of tokenizer to use.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--vocab-file", type=str, default=None, help="Path to the vocab file"
|
||||
)
|
||||
group.add_argument(
|
||||
"--merge-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the BPE merge file (if necessary).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--append-eod",
|
||||
action="store_true",
|
||||
help="Append an <eod> token to the end of a document.",
|
||||
)
|
||||
group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text")
|
||||
group = parser.add_argument_group(title="output data")
|
||||
group.add_argument(
|
||||
"--output-prefix",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to binary output file without suffix",
|
||||
)
|
||||
group.add_argument(
|
||||
"--dataset-impl",
|
||||
type=str,
|
||||
default="mmap",
|
||||
choices=["lazy", "cached", "mmap"],
|
||||
help="Dataset implementation to use. Default: mmap",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group(title="runtime")
|
||||
group.add_argument(
|
||||
"--workers", type=int, default=1, help="Number of worker processes to launch"
|
||||
)
|
||||
group.add_argument(
|
||||
"--log-interval",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Interval between progress updates",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args.keep_empty = False
|
||||
|
||||
# some default/dummy values for the tokenizer
|
||||
args.rank = 0
|
||||
args.make_vocab_size_divisible_by = 128
|
||||
args.model_parallel_size = 1
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def yield_from_files(fnames: list, semaphore):
|
||||
"""
|
||||
Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts /
|
||||
other compressed formats. Also filters out empty documents.
|
||||
|
||||
:param fnames: list of filenames
|
||||
"""
|
||||
|
||||
def yielder(fname, semaphore):
|
||||
for f in filter(lambda x: x, lmd.Reader(fname).stream_data()):
|
||||
semaphore.acquire()
|
||||
yield f
|
||||
|
||||
for fname in fnames:
|
||||
semaphore.acquire()
|
||||
|
||||
yield from yielder(fname, semaphore)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
encoder = Encoder(args)
|
||||
tokenizer = build_tokenizer(args)
|
||||
print(f"Vocab size: {tokenizer.vocab_size}")
|
||||
print(f"Output prefix: {args.output_prefix}")
|
||||
|
||||
# build a semaphore object to stop `yield_from_files` from getting ahead of encoder.encode and
|
||||
# hence building up memory
|
||||
semaphore = Semaphore(10000 + args.workers)
|
||||
|
||||
# use multiprocessing to iterate over input documents
|
||||
fin = yield_from_files(args.input.split(","), semaphore)
|
||||
|
||||
if args.workers > 1:
|
||||
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
|
||||
encoded_docs = pool.imap(encoder.encode, fin, chunksize=25)
|
||||
else:
|
||||
encoder.initializer()
|
||||
encoded_docs = (encoder.encode(doc) for doc in fin)
|
||||
|
||||
# make a dataset builder for each key in args.jsonl_keys
|
||||
# each key will output to a different file beginning with args.output_prefix
|
||||
output_bin_files = {}
|
||||
output_idx_files = {}
|
||||
builders = {}
|
||||
for key in args.jsonl_keys:
|
||||
output_bin_files[key] = "{}_{}_{}.bin".format(
|
||||
args.output_prefix, key, "document"
|
||||
)
|
||||
output_idx_files[key] = "{}_{}_{}.idx".format(
|
||||
args.output_prefix, key, "document"
|
||||
)
|
||||
builders[key] = indexed_dataset.make_builder(
|
||||
output_bin_files[key],
|
||||
impl=args.dataset_impl,
|
||||
vocab_size=tokenizer.vocab_size,
|
||||
)
|
||||
|
||||
# actually do tokenization
|
||||
proc_start = time.time()
|
||||
total_bytes_processed = 0
|
||||
pbar = tqdm.tqdm()
|
||||
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
|
||||
total_bytes_processed += bytes_processed
|
||||
|
||||
# release semaphore so `yield_from_files` can add another file to the buffer
|
||||
semaphore.release()
|
||||
|
||||
# add each tokenized document / sentence
|
||||
for key, sentences in doc.items():
|
||||
for sentence in sentences:
|
||||
builders[key].add_item(np.array(sentence, dtype=builders[key].dtype))
|
||||
# separate with eos token
|
||||
builders[key].end_document()
|
||||
|
||||
# log progress
|
||||
if i % args.log_interval == 0:
|
||||
current = time.time()
|
||||
elapsed = current - proc_start
|
||||
mbs = total_bytes_processed / elapsed / 1024 / 1024
|
||||
pbar.set_description(
|
||||
f"Processed {i}{'' if args.num_docs is None else '/' + str(args.num_docs)} documents ({i / elapsed:0.2f} docs/s, {mbs:0.2f} MB/s)."
|
||||
)
|
||||
if i != 0:
|
||||
pbar.update(args.log_interval)
|
||||
|
||||
# save output file
|
||||
for key in args.jsonl_keys:
|
||||
builders[key].finalize(output_idx_files[key])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
232
finetune/json2binidx_tool/tools/rwkv_tokenizer.py
vendored
Normal file
232
finetune/json2binidx_tool/tools/rwkv_tokenizer.py
vendored
Normal file
@ -0,0 +1,232 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
# Source: https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_tokenizer.py
|
||||
########################################################################################################
|
||||
|
||||
import os, sys, time, random
|
||||
|
||||
print('''
|
||||
#######################################################################################################################
|
||||
|
||||
This tokenizer is not used in any RWKV models yet. I plan to use it for the future multilang RWKV models.
|
||||
|
||||
Benefits:
|
||||
|
||||
* Good support of most languages, from European to CJK to Arabic and Hindi and more.
|
||||
|
||||
* Clean vocab. Good for code too. Vocab size = 65525 (use 0 for <|endoftext|>).
|
||||
|
||||
* Good at numbers: the numerical tokens are '0'~'9', '10'~'99', ' 0'~' 9', ' 10'~' 99'.
|
||||
|
||||
* Very easy tokenization:
|
||||
|
||||
** The input text must be in UTF-8.
|
||||
|
||||
** Greedy encoding: always pick the longest (in bytes) token (with the highest id) that matches your UTF-8 bytes.
|
||||
|
||||
* The tokenization result is surprisingly good, because the vocab respects word boundaries and UTF-8 boundaries.
|
||||
|
||||
For 10x faster speed:
|
||||
mypyc rwkv_tokenizer.py
|
||||
python3 -c "import rwkv_tokenizer"
|
||||
|
||||
#######################################################################################################################
|
||||
''')
|
||||
|
||||
########################################################################################################
|
||||
# Tokenizer #1 (reference, naive, slow)
|
||||
########################################################################################################
|
||||
|
||||
class RWKV_TOKENIZER():
|
||||
table = None # : list[list[list[bytes]]] = None
|
||||
good = None # : list[set[int]]
|
||||
wlen = None # : list[int]
|
||||
def __init__(self, file_name):
|
||||
self.vocab_size = 65525
|
||||
self.idx2token = {}
|
||||
sorted = [] # must be already sorted
|
||||
lines = open(file_name, "r", encoding="utf-8").readlines()
|
||||
for l in lines:
|
||||
idx = int(l[:l.index(' ')])
|
||||
x = eval(l[l.index(' '):l.rindex(' ')])
|
||||
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||
assert isinstance(x, bytes)
|
||||
assert len(x) == int(l[l.rindex(' '):])
|
||||
sorted += [x]
|
||||
self.idx2token[idx] = x
|
||||
|
||||
self.token2idx = {}
|
||||
for k, v in self.idx2token.items():
|
||||
self.token2idx[v] = int(k)
|
||||
|
||||
# precompute some tables for fast matching
|
||||
self.table = [[[] for j in range(256)] for i in range(256)]
|
||||
self.good = [set() for i in range(256)]
|
||||
self.wlen = [0 for i in range(256)]
|
||||
|
||||
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
|
||||
s = sorted[i]
|
||||
if len(s) >= 2:
|
||||
s0 = int(s[0])
|
||||
s1 = int(s[1])
|
||||
self.table[s0][s1] += [s]
|
||||
self.wlen[s0] = max(self.wlen[s0], len(s))
|
||||
self.good[s0].add(s1)
|
||||
|
||||
def encodeBytes(self, src: bytes):
|
||||
src_len: int = len(src)
|
||||
tokens = []
|
||||
i: int = 0
|
||||
while i < src_len:
|
||||
s: bytes = src[i : i + 1]
|
||||
|
||||
if i < src_len - 1:
|
||||
s1: int = int(src[i + 1])
|
||||
s0: int = int(src[i])
|
||||
if s1 in self.good[s0]:
|
||||
sss: bytes = src[i : i + self.wlen[s0]]
|
||||
try:
|
||||
s = next(filter(sss.startswith, self.table[s0][s1]))
|
||||
except:
|
||||
pass
|
||||
tokens.append(self.token2idx[s])
|
||||
i += len(s)
|
||||
|
||||
return tokens
|
||||
|
||||
def decodeBytes(self, tokens):
|
||||
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
||||
|
||||
def encode(self, src: str):
|
||||
return self.encodeBytes(src.encode("utf-8"))
|
||||
|
||||
def decode(self, tokens):
|
||||
return self.decodeBytes(tokens).decode('utf-8')
|
||||
|
||||
def token_to_id(self, token):
|
||||
return self.token2idx[token]
|
||||
|
||||
def get_vocab_size(self):
|
||||
return self.vocab_size
|
||||
|
||||
def get_vocab(self):
|
||||
return self.idx2token
|
||||
|
||||
def printTokens(self, tokens):
|
||||
for i in tokens:
|
||||
s = self.idx2token[i]
|
||||
try:
|
||||
s = s.decode('utf-8')
|
||||
except:
|
||||
pass
|
||||
print(f'{repr(s)}{i}', end=' ')
|
||||
# print(repr(s), i)
|
||||
print()
|
||||
|
||||
########################################################################################################
|
||||
# Tokenizer #2 (trie, faster) https://github.com/TkskKurumi/ChatRWKV-TRIE-Tokenizer
|
||||
########################################################################################################
|
||||
|
||||
class TRIE:
|
||||
__slots__ = tuple("ch,to,values,front".split(","))
|
||||
to:list
|
||||
values:set
|
||||
def __init__(self, front=None, ch=None):
|
||||
self.ch = ch
|
||||
self.to = [None for ch in range(256)]
|
||||
self.values = set()
|
||||
self.front = front
|
||||
|
||||
def __repr__(self):
|
||||
fr = self
|
||||
ret = []
|
||||
while(fr!=None):
|
||||
if(fr.ch!=None):
|
||||
ret.append(fr.ch)
|
||||
fr = fr.front
|
||||
return "<TRIE %s %s>"%(ret[::-1], self.values)
|
||||
|
||||
def add(self, key:bytes, idx:int=0, val=None):
|
||||
if(idx == len(key)):
|
||||
if(val is None):
|
||||
val = key
|
||||
self.values.add(val)
|
||||
return self
|
||||
ch = key[idx]
|
||||
if(self.to[ch] is None):
|
||||
self.to[ch] = TRIE(front=self, ch=ch)
|
||||
return self.to[ch].add(key, idx=idx+1, val=val)
|
||||
|
||||
def find_longest(self, key:bytes, idx:int=0):
|
||||
u:TRIE = self
|
||||
ch:int = key[idx]
|
||||
|
||||
while(u.to[ch] is not None):
|
||||
u = u.to[ch]
|
||||
idx += 1
|
||||
if(u.values):
|
||||
ret = idx, u, u.values
|
||||
if(idx==len(key)):
|
||||
break
|
||||
ch = key[idx]
|
||||
return ret
|
||||
|
||||
class TRIE_TOKENIZER():
|
||||
def __init__(self, file_name):
|
||||
self.vocab_size = 65525
|
||||
self.idx2token = {}
|
||||
sorted = [] # must be already sorted
|
||||
with open(file_name, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for l in lines:
|
||||
idx = int(l[:l.index(' ')])
|
||||
x = eval(l[l.index(' '):l.rindex(' ')])
|
||||
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||
assert isinstance(x, bytes)
|
||||
assert len(x) == int(l[l.rindex(' '):])
|
||||
sorted += [x]
|
||||
self.idx2token[idx] = x
|
||||
|
||||
self.token2idx = {}
|
||||
for k,v in self.idx2token.items():
|
||||
self.token2idx[v] = int(k)
|
||||
|
||||
self.root = TRIE()
|
||||
for t, i in self.token2idx.items():
|
||||
_ = self.root.add(t, val=(t, i))
|
||||
|
||||
def encodeBytes(self, src:bytes):
|
||||
idx:int = 0
|
||||
tokens = []
|
||||
while (idx < len(src)):
|
||||
_idx:int = idx
|
||||
idx, _, values = self.root.find_longest(src, idx)
|
||||
assert(idx != _idx)
|
||||
_, token = next(iter(values))
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def decodeBytes(self, tokens):
|
||||
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
||||
|
||||
def encode(self, src):
|
||||
return self.encodeBytes(src.encode("utf-8"))
|
||||
|
||||
def decode(self, tokens):
|
||||
return self.decodeBytes(tokens).decode('utf-8')
|
||||
|
||||
def get_vocab_size(self):
|
||||
return self.vocab_size
|
||||
|
||||
def get_vocab(self):
|
||||
return self.idx2token
|
||||
|
||||
def printTokens(self, tokens):
|
||||
for i in tokens:
|
||||
s = self.idx2token[i]
|
||||
try:
|
||||
s = s.decode('utf-8')
|
||||
except:
|
||||
pass
|
||||
print(f'{repr(s)}{i}', end=' ')
|
||||
print()
|
205
finetune/json2binidx_tool/tools/tokenizer.py
vendored
Normal file
205
finetune/json2binidx_tool/tools/tokenizer.py
vendored
Normal file
@ -0,0 +1,205 @@
|
||||
# Copyright (c) 2021, EleutherAI
|
||||
# This file is based on code by the authors denoted below and has been modified from its original version.
|
||||
#
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Megatron tokenizers."""
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
from tokenizers import Tokenizer
|
||||
from rwkv_tokenizer import RWKV_TOKENIZER, TRIE_TOKENIZER
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def build_tokenizer(args):
|
||||
"""Initialize tokenizer."""
|
||||
if args.rank == 0:
|
||||
print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True)
|
||||
|
||||
# Select and instantiate the tokenizer.
|
||||
|
||||
if args.tokenizer_type.lower() == "HFTokenizer".lower():
|
||||
assert args.vocab_file is not None
|
||||
tokenizer = HFTokenizer(args.vocab_file)
|
||||
elif args.tokenizer_type.lower() == "RWKVTokenizer".lower():
|
||||
assert args.vocab_file is not None
|
||||
tokenizer = RWKVTokenizer(args.vocab_file)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"{} tokenizer is not " "implemented.".format(args.tokenizer_type)
|
||||
)
|
||||
|
||||
# Add vocab size.
|
||||
args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _vocab_size_with_padding(orig_vocab_size, args):
|
||||
"""Pad vocab size so it is divisible by model parallel size and
|
||||
still having GPU friendly size."""
|
||||
|
||||
after = orig_vocab_size
|
||||
multiple = args.make_vocab_size_divisible_by * args.model_parallel_size
|
||||
while (after % multiple) != 0:
|
||||
after += 1
|
||||
if args.rank == 0:
|
||||
print(
|
||||
" > padded vocab (size: {}) with {} dummy tokens "
|
||||
"(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after),
|
||||
flush=True,
|
||||
)
|
||||
return after
|
||||
|
||||
|
||||
class AbstractTokenizer(ABC):
|
||||
"""Abstract class for tokenizer."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab_size(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def vocab(self):
|
||||
"""Dictionary from vocab text token to id token."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def inv_vocab(self):
|
||||
"""Dictionary from vocab id token to text token."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def tokenize(self, text):
|
||||
pass
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
raise NotImplementedError(
|
||||
"detokenizer is not implemented for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
raise NotImplementedError(
|
||||
"CLS is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def sep(self):
|
||||
raise NotImplementedError(
|
||||
"SEP is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
raise NotImplementedError(
|
||||
"PAD is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
raise NotImplementedError(
|
||||
"EOD is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
@property
|
||||
def mask(self):
|
||||
raise NotImplementedError(
|
||||
"MASK is not provided for {} " "tokenizer".format(self.name)
|
||||
)
|
||||
|
||||
|
||||
class HFTokenizer(AbstractTokenizer):
|
||||
"""Designed to Integrate HF's Tokenizer library."""
|
||||
|
||||
def __init__(self, vocab_file):
|
||||
name = "HFTokenizer"
|
||||
super().__init__(name)
|
||||
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
self.eod_id = self.tokenizer.token_to_id("<|endoftext|>")
|
||||
self.pad_id = self.tokenizer.token_to_id("<|padding|>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self.tokenizer.decoder
|
||||
|
||||
def tokenize(self, text: str):
|
||||
return self.tokenizer.encode(text).ids
|
||||
|
||||
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
||||
return self.tokenizer.encode_batch(text_batch)
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
return self.eod_id
|
||||
|
||||
|
||||
class RWKVTokenizer(AbstractTokenizer):
|
||||
"""RWKV Worlds Tokenizer."""
|
||||
|
||||
def __init__(self, vocab_file='rwkv_vocab_v20230424.txt'):
|
||||
name = "RWKVTokenizer"
|
||||
super().__init__(name)
|
||||
|
||||
self.tokenizer = TRIE_TOKENIZER(vocab_file)
|
||||
self.eod_id = 0 # self.tokenizer.token_to_id("<|endoftext|>")
|
||||
# self.pad_id = self.tokenizer.token_to_id("<|padding|>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.get_vocab_size()
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self.tokenizer.decode
|
||||
|
||||
def tokenize(self, text: str):
|
||||
return self.tokenizer.encode(text)
|
||||
|
||||
def tokenize_batch(self, text_batch: Union[List[str], str]):
|
||||
return self.tokenizer.encode_batch(text_batch)
|
||||
|
||||
def detokenize(self, token_ids):
|
||||
return self.tokenizer.decode(token_ids)
|
||||
|
||||
@property
|
||||
def eod(self):
|
||||
return self.eod_id
|
133
finetune/lora/cuda/wkv_cuda.cu
vendored
Normal file
133
finetune/lora/cuda/wkv_cuda.cu
vendored
Normal file
@ -0,0 +1,133 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
|
||||
#define MIN_VALUE (-1e38)
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_forward(const int B, const int T, const int C,
|
||||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
F *__restrict__ const _y) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
F u = _u[_c];
|
||||
F w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
F *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
F aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
|
||||
F ww = u + kk;
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p);
|
||||
F e2 = exp(ww - p);
|
||||
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
|
||||
const F *__restrict__ const _y, const F *__restrict__ const _gy,
|
||||
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
F u = _u[_c];
|
||||
F w = _w[_c];
|
||||
const F *__restrict__ const k = _k + _offset;
|
||||
const F *__restrict__ const v = _v + _offset;
|
||||
const F *__restrict__ const y = _y + _offset;
|
||||
const F *__restrict__ const gy = _gy + _offset;
|
||||
F *__restrict__ const gk = _gk + _offset;
|
||||
F *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
F q[Tmax], r[Tmax];
|
||||
|
||||
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
|
||||
F ww = u + kk;
|
||||
F p = max(pp, ww);
|
||||
F e1 = exp(pp - p);
|
||||
F e2 = exp(ww - p);
|
||||
const F qq = gy[ii] / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = gu;
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const F kk = k[ii];
|
||||
const F vv = v[ii];
|
||||
const F yy = y[ii];
|
||||
const F qq = q[i];
|
||||
const F rr = r[i];
|
||||
|
||||
F e1 = qq * exp(rr);
|
||||
F e2 = exp(kk + pp);
|
||||
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
|
||||
gv[ii] = e1 + e2 * aa;
|
||||
|
||||
const F ww = w + pp;
|
||||
const F www = rr - u - kk;
|
||||
const F p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
132
finetune/lora/cuda/wkv_cuda_bf16.cu
vendored
Normal file
132
finetune/lora/cuda/wkv_cuda_bf16.cu
vendored
Normal file
@ -0,0 +1,132 @@
|
||||
#include <stdio.h>
|
||||
#include <assert.h>
|
||||
#include "ATen/ATen.h"
|
||||
#define MIN_VALUE (-1e38)
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
__global__ void kernel_forward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
bf16 *__restrict__ const _y) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
bf16 *__restrict__ const y = _y + _offset;
|
||||
|
||||
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
|
||||
float aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void kernel_backward(const int B, const int T, const int C,
|
||||
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
|
||||
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
|
||||
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int _b = idx / C;
|
||||
const int _c = idx % C;
|
||||
const int _offset = _b * T * C + _c;
|
||||
|
||||
float u = float(_u[_c]);
|
||||
float w = _w[_c];
|
||||
const bf16 *__restrict__ const k = _k + _offset;
|
||||
const bf16 *__restrict__ const v = _v + _offset;
|
||||
const bf16 *__restrict__ const y = _y + _offset;
|
||||
const bf16 *__restrict__ const gy = _gy + _offset;
|
||||
bf16 *__restrict__ const gk = _gk + _offset;
|
||||
bf16 *__restrict__ const gv = _gv + _offset;
|
||||
|
||||
float q[Tmax], r[Tmax];
|
||||
|
||||
float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
|
||||
for (int i = 0; i < T; i++) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
|
||||
float ww = u + kk;
|
||||
float p = max(pp, ww);
|
||||
float e1 = exp(pp - p);
|
||||
float e2 = exp(ww - p);
|
||||
const float qq = float(gy[ii]) / (e1 * bb + e2);
|
||||
gw += (ga - gb * yy) * e1 * qq;
|
||||
gu += (vv - yy) * e2 * qq;
|
||||
q[i] = qq;
|
||||
r[i] = ww - p;
|
||||
|
||||
ww = w + pp;
|
||||
p = max(ww, kk);
|
||||
e1 = exp(ww - p);
|
||||
e2 = exp(kk - p);
|
||||
ga = e1 * (aa + ga);
|
||||
gb = e1 * (bb + gb);
|
||||
aa = e1 * aa + e2 * vv;
|
||||
bb = e1 * bb + e2;
|
||||
pp = p;
|
||||
}
|
||||
const int _offsetBC = _b * C + _c;
|
||||
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
|
||||
_gu[_offsetBC] = bf16(gu);
|
||||
|
||||
aa = 0, bb = 0, pp = MIN_VALUE;
|
||||
for (int i = T - 1; i >= 0; i--) {
|
||||
const int ii = i * C;
|
||||
const float kk = float(k[ii]);
|
||||
const float vv = float(v[ii]);
|
||||
const float yy = float(y[ii]);
|
||||
const float qq = q[i];
|
||||
const float rr = r[i];
|
||||
|
||||
float e1 = qq * exp(rr);
|
||||
float e2 = exp(kk + pp);
|
||||
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
|
||||
gv[ii] = bf16(e1 + e2 * aa);
|
||||
|
||||
const float ww = w + pp;
|
||||
const float www = rr - u - kk;
|
||||
const float p = max(ww, www);
|
||||
e1 = exp(ww - p);
|
||||
e2 = qq * exp(www - p);
|
||||
aa = e1 * aa + e2;
|
||||
bb = e1 * bb - e2 * yy;
|
||||
pp = p;
|
||||
}
|
||||
}
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
|
||||
}
|
||||
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
|
||||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
|
||||
assert(B * C % threadsPerBlock.x == 0);
|
||||
dim3 numBlocks(B * C / threadsPerBlock.x);
|
||||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
|
||||
}
|
21
finetune/lora/cuda/wkv_op.cpp
vendored
Normal file
21
finetune/lora/cuda/wkv_op.cpp
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
25
finetune/lora/cuda/wkv_op_bf16.cpp
vendored
Normal file
25
finetune/lora/cuda/wkv_op_bf16.cpp
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
#include <torch/extension.h>
|
||||
#include "ATen/ATen.h"
|
||||
typedef at::BFloat16 bf16;
|
||||
|
||||
void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
|
||||
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
|
||||
|
||||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
|
||||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
|
||||
}
|
||||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
|
||||
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
|
||||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
|
||||
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &forward, "wkv forward");
|
||||
m.def("backward", &backward, "wkv backward");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(wkv, m) {
|
||||
m.def("forward", forward);
|
||||
m.def("backward", backward);
|
||||
}
|
53
finetune/lora/merge_lora.py
vendored
Normal file
53
finetune/lora/merge_lora.py
vendored
Normal file
@ -0,0 +1,53 @@
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
import typing
|
||||
import torch
|
||||
|
||||
if '-h' in sys.argv or '--help' in sys.argv:
|
||||
print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')
|
||||
|
||||
if sys.argv[1] == '--use-gpu':
|
||||
device = 'cuda'
|
||||
lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
|
||||
else:
|
||||
device = 'cpu'
|
||||
lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
|
||||
# merge LoRA-only slim checkpoint into the main weights
|
||||
w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
|
||||
for k in w_lora.keys():
|
||||
w[k] = w_lora[k]
|
||||
output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
|
||||
# merge LoRA weights
|
||||
keys = list(w.keys())
|
||||
for k in keys:
|
||||
if k.endswith('.weight'):
|
||||
prefix = k[:-len('.weight')]
|
||||
lora_A = prefix + '.lora_A'
|
||||
lora_B = prefix + '.lora_B'
|
||||
if lora_A in keys:
|
||||
assert lora_B in keys
|
||||
print(f'merging {lora_A} and {lora_B} into {k}')
|
||||
assert w[lora_B].shape[1] == w[lora_A].shape[0]
|
||||
lora_r = w[lora_B].shape[1]
|
||||
w[k] = w[k].to(device=device)
|
||||
w[lora_A] = w[lora_A].to(device=device)
|
||||
w[lora_B] = w[lora_B].to(device=device)
|
||||
w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
|
||||
output_w[k] = w[k].to(device='cpu', copy=True)
|
||||
del w[k]
|
||||
del w[lora_A]
|
||||
del w[lora_B]
|
||||
continue
|
||||
|
||||
if 'lora' not in k:
|
||||
print(f'retaining {k}')
|
||||
output_w[k] = w[k].clone()
|
||||
del w[k]
|
||||
|
||||
torch.save(output_w, output)
|
0
finetune/lora/src/__init__.py
vendored
Normal file
0
finetune/lora/src/__init__.py
vendored
Normal file
269
finetune/lora/src/binidx.py
vendored
Normal file
269
finetune/lora/src/binidx.py
vendored
Normal file
@ -0,0 +1,269 @@
|
||||
from lib2to3.pgen2 import token
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import shutil
|
||||
import struct
|
||||
from functools import lru_cache
|
||||
from itertools import accumulate
|
||||
|
||||
def print_rank_0(*message):
|
||||
pass
|
||||
# """If distributed is initialized print only on rank 0."""
|
||||
# if torch.distributed.is_initialized():
|
||||
# if torch.distributed.get_rank() == 0:
|
||||
# print(*message, flush=True)
|
||||
# else:
|
||||
# print(*message, flush=True)
|
||||
|
||||
def _warmup_mmap_file(path):
|
||||
pass
|
||||
# with open(path, "rb") as stream:
|
||||
# while stream.read(100 * 1024 * 1024):
|
||||
# pass
|
||||
|
||||
dtypes = {
|
||||
1: np.uint8,
|
||||
2: np.int8,
|
||||
3: np.int16,
|
||||
4: np.int32,
|
||||
5: np.int64,
|
||||
6: float,
|
||||
7: np.double,
|
||||
8: np.uint16,
|
||||
}
|
||||
|
||||
def code(dtype):
|
||||
for k in dtypes.keys():
|
||||
if dtypes[k] == dtype:
|
||||
return k
|
||||
raise ValueError(dtype)
|
||||
|
||||
def index_file_path(prefix_path):
|
||||
return prefix_path + ".idx"
|
||||
|
||||
def data_file_path(prefix_path):
|
||||
return prefix_path + ".bin"
|
||||
|
||||
class MMapIndexedDataset(torch.utils.data.Dataset):
|
||||
class Index(object):
|
||||
_HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
|
||||
@classmethod
|
||||
def writer(cls, path, dtype):
|
||||
class _Writer(object):
|
||||
def __enter__(self):
|
||||
self._file = open(path, "wb")
|
||||
|
||||
# Write Magic string so we can check the file format then opening it again.
|
||||
self._file.write(cls._HDR_MAGIC)
|
||||
# Write version number
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", 1))
|
||||
# Little endian unsigned 8 Bit integer
|
||||
self._file.write(struct.pack("<B", code(dtype)))
|
||||
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _get_pointers(sizes):
|
||||
dtype_size = dtype().itemsize
|
||||
address = 0
|
||||
pointers = []
|
||||
|
||||
for size in sizes:
|
||||
pointers.append(address)
|
||||
address += size * dtype_size
|
||||
|
||||
return pointers
|
||||
|
||||
def write(self, sizes, doc_idx):
|
||||
pointers = self._get_pointers(sizes)
|
||||
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(sizes)))
|
||||
# Little endian unsigned 64 Bit integer
|
||||
self._file.write(struct.pack("<Q", len(doc_idx)))
|
||||
|
||||
sizes = np.array(sizes, dtype=np.int32)
|
||||
self._file.write(sizes.tobytes(order="C"))
|
||||
del sizes
|
||||
|
||||
pointers = np.array(pointers, dtype=np.int64)
|
||||
self._file.write(pointers.tobytes(order="C"))
|
||||
del pointers
|
||||
|
||||
doc_idx = np.array(doc_idx, dtype=np.int64)
|
||||
self._file.write(doc_idx.tobytes(order="C"))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._file.close()
|
||||
|
||||
return _Writer()
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
with open(path, "rb") as stream:
|
||||
magic_test = stream.read(9)
|
||||
assert self._HDR_MAGIC == magic_test, (
|
||||
"Index file doesn't match expected format. "
|
||||
"Make sure that --dataset-impl is configured properly."
|
||||
)
|
||||
# Little endian unsigned 64 Bit integer
|
||||
version = struct.unpack("<Q", stream.read(8))
|
||||
assert (1,) == version
|
||||
|
||||
# Little endian unsigned 8 Bit integer
|
||||
(dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
self._dtype = dtypes[dtype_code]
|
||||
self._dtype_size = self._dtype().itemsize
|
||||
|
||||
self._len = struct.unpack("<Q", stream.read(8))[0]
|
||||
self._doc_count = struct.unpack("<Q", stream.read(8))[0]
|
||||
offset = stream.tell()
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up index mmap file...")
|
||||
_warmup_mmap_file(path)
|
||||
|
||||
self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
print_rank_0(" reading sizes...")
|
||||
self._sizes = np.frombuffer(
|
||||
self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
|
||||
)
|
||||
print_rank_0(" reading pointers...")
|
||||
self._pointers = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._len,
|
||||
offset=offset + self._sizes.nbytes,
|
||||
)
|
||||
print_rank_0(" reading document index...")
|
||||
self._doc_idx = np.frombuffer(
|
||||
self._bin_buffer,
|
||||
dtype=np.int64,
|
||||
count=self._doc_count,
|
||||
offset=offset + self._sizes.nbytes + self._pointers.nbytes,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._doc_idx
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def __getitem__(self, i):
|
||||
return self._pointers[i], self._sizes[i]
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __init__(self, path, skip_warmup=False):
|
||||
super().__init__()
|
||||
|
||||
self._path = None
|
||||
self._index = None
|
||||
self._bin_buffer = None
|
||||
|
||||
self._do_init(path, skip_warmup)
|
||||
|
||||
def __getstate__(self):
|
||||
return self._path
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._do_init(state)
|
||||
|
||||
def _do_init(self, path, skip_warmup):
|
||||
self._path = path
|
||||
self._index = self.Index(index_file_path(self._path), skip_warmup)
|
||||
|
||||
if not skip_warmup:
|
||||
print_rank_0(" warming up data mmap file...")
|
||||
_warmup_mmap_file(data_file_path(self._path))
|
||||
print_rank_0(" creating numpy buffer of mmap...")
|
||||
self._bin_buffer_mmap = np.memmap(
|
||||
data_file_path(self._path), mode="r", order="C"
|
||||
)
|
||||
print_rank_0(" creating memory view of numpy buffer...")
|
||||
self._bin_buffer = memoryview(self._bin_buffer_mmap)
|
||||
|
||||
def __del__(self):
|
||||
self._bin_buffer_mmap._mmap.close()
|
||||
del self._bin_buffer_mmap
|
||||
del self._index
|
||||
|
||||
def __len__(self):
|
||||
return len(self._index)
|
||||
|
||||
# @lru_cache(maxsize=8)
|
||||
def __getitem__(self, idx):
|
||||
if isinstance(idx, int):
|
||||
ptr, size = self._index[idx]
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(len(self))
|
||||
if step != 1:
|
||||
raise ValueError(
|
||||
"Slices into indexed_dataset must be contiguous")
|
||||
ptr = self._index._pointers[start]
|
||||
sizes = self._index._sizes[idx]
|
||||
offsets = list(accumulate(sizes))
|
||||
total_size = sum(sizes)
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr
|
||||
)
|
||||
sents = np.split(np_array, offsets[:-1])
|
||||
return sents
|
||||
|
||||
def get(self, idx, offset=0, length=None):
|
||||
"""Retrieves a single item from the dataset with the option to only
|
||||
return a portion of the item.
|
||||
|
||||
get(idx) is the same as [idx] but get() does not support slicing.
|
||||
"""
|
||||
ptr, size = self._index[idx]
|
||||
if length is None:
|
||||
length = size - offset
|
||||
ptr += offset * np.dtype(self._index.dtype).itemsize
|
||||
np_array = np.frombuffer(
|
||||
self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr
|
||||
)
|
||||
return np_array
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._index.sizes
|
||||
|
||||
@property
|
||||
def doc_idx(self):
|
||||
return self._index.doc_idx
|
||||
|
||||
def get_doc_idx(self):
|
||||
return self._index._doc_idx
|
||||
|
||||
def set_doc_idx(self, doc_idx_):
|
||||
self._index._doc_idx = doc_idx_
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def exists(path):
|
||||
return os.path.exists(index_file_path(path)) and os.path.exists(
|
||||
data_file_path(path)
|
||||
)
|
224
finetune/lora/src/dataset.py
vendored
Normal file
224
finetune/lora/src/dataset.py
vendored
Normal file
@ -0,0 +1,224 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import json, math, random, os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from .binidx import MMapIndexedDataset
|
||||
from .utils import MaybeIsPrime
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
if args.data_type == "binidx":
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
|
||||
|
||||
if args.data_file.endswith('/'):
|
||||
d_all = []
|
||||
for p in os.listdir(args.data_file):
|
||||
if p.endswith(".idx"):
|
||||
d_all += [p[:-4]]
|
||||
d_all.sort()
|
||||
rank_zero_info(d_all)
|
||||
exit(0)
|
||||
else:
|
||||
self.data = MMapIndexedDataset(args.data_file)
|
||||
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
|
||||
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
# assert self.data_size == 332115325534 and self.vocab_size == 50277
|
||||
self.samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
assert self.samples_per_epoch == 40320
|
||||
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
|
||||
dataset_slot = self.data_size // args.ctx_len
|
||||
if args.my_pile_stage != 4:
|
||||
assert MaybeIsPrime(args.magic_prime)
|
||||
assert args.magic_prime % 3 == 2
|
||||
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
|
||||
elif args.data_type == "numpy":
|
||||
self.data = np.load(args.data_file).astype("int")
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens.")
|
||||
elif args.data_type == "uint16":
|
||||
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
|
||||
self.vocab_size = args.vocab_size
|
||||
rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
|
||||
self.data_size = self.data.shape[0]
|
||||
rank_zero_info(f"Data has {self.data_size} samples.")
|
||||
elif args.data_type == "wds_img":
|
||||
self.vocab_size = -1
|
||||
self.data_size = -1
|
||||
self.data = None
|
||||
self.error_count = 0
|
||||
else:
|
||||
if args.data_type == "dummy":
|
||||
rank_zero_info("Building dummy data...")
|
||||
self.data = ""
|
||||
for i in range(100000):
|
||||
aa = (i) % 10000
|
||||
bb = (i * i) % 10000
|
||||
cc = aa + bb
|
||||
self.data += f".{aa}+{bb}={cc}."
|
||||
else:
|
||||
self.data = open(args.data_file, "r", encoding=args.data_type).read()
|
||||
rank_zero_info("Building token list...")
|
||||
unique = sorted(list(set(self.data)))
|
||||
self.vocab_size = len(unique)
|
||||
# rank_zero_info()
|
||||
# for u in unique:
|
||||
# print(u, end=' ')
|
||||
# rank_zero_info('\n\n')
|
||||
xx = 0
|
||||
xxObj = {}
|
||||
for u in unique:
|
||||
xxObj[xx] = u
|
||||
xx += 1
|
||||
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
|
||||
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
|
||||
self.data_size = len(self.data)
|
||||
rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
|
||||
self.stoi = {ch: i for i, ch in enumerate(unique)}
|
||||
self.itos = {i: ch for i, ch in enumerate(unique)}
|
||||
|
||||
def __len__(self):
|
||||
return self.args.epoch_steps * self.args.micro_bsz
|
||||
|
||||
def __getitem__(self, idx):
|
||||
args = self.args
|
||||
rank = self.global_rank
|
||||
epoch = self.real_epoch
|
||||
world_size = self.world_size
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}")
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
def init_wds(self, bias=0):
|
||||
def identity(x):
|
||||
return x
|
||||
import webdataset as wds
|
||||
import torchvision.transforms as transforms
|
||||
# img_transform = transforms.Compose(
|
||||
# [transforms.CenterCrop(256)]
|
||||
# )
|
||||
img_transform = transforms.Compose([
|
||||
transforms.CenterCrop(512),
|
||||
transforms.Resize((args.my_img_size))
|
||||
])
|
||||
self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity)
|
||||
for pp in self.data_raw.pipeline:
|
||||
if 'Resampled' in str(pp):
|
||||
pp.deterministic = True
|
||||
def worker_seed():
|
||||
return rank*100000+epoch+bias*1e9
|
||||
pp.worker_seed = worker_seed
|
||||
self.data = iter(self.data_raw)
|
||||
# print(f"WebDataset loaded for rank {rank} epoch {epoch}")
|
||||
if self.data == None:
|
||||
init_wds(self)
|
||||
trial = 0
|
||||
while trial < 10:
|
||||
try:
|
||||
dd = next(self.data) # jpg, json, txt
|
||||
break
|
||||
except:
|
||||
print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]')
|
||||
self.error_count += 1
|
||||
init_wds(self, self.error_count)
|
||||
trial += 1
|
||||
pass
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}")
|
||||
# with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp:
|
||||
# tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n")
|
||||
return dd[0], dd[2]
|
||||
else:
|
||||
if args.data_type == "uint16":
|
||||
i = np.random.randint(0, self.data_size-1)
|
||||
dix = self.data[i]
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
else:
|
||||
ctx_len = args.ctx_len
|
||||
req_len = ctx_len + 1
|
||||
magic_prime = args.magic_prime
|
||||
data = self.data
|
||||
|
||||
if args.my_pile_stage > 0 and args.my_pile_stage != 4:
|
||||
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank
|
||||
|
||||
if args.my_qa_mask > 0:
|
||||
ii_orig = ii
|
||||
if ii % 2 == 0:
|
||||
ii = (ii // 2) * args.magic_prime
|
||||
if args.ctx_len == 1024:
|
||||
magic_prime = 324331313
|
||||
elif args.ctx_len == 2048:
|
||||
magic_prime = 162165671
|
||||
elif args.ctx_len == 4096:
|
||||
magic_prime = 81082817
|
||||
data = self.data_pile
|
||||
else:
|
||||
ii = ii // 2
|
||||
|
||||
factor = (math.sqrt(5) - 1) / 2
|
||||
factor = int(magic_prime * factor)
|
||||
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
|
||||
if (args.my_qa_mask == 0) or (data == self.data_pile):
|
||||
i = i + args.my_pile_shift
|
||||
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
|
||||
else:
|
||||
# cheat: pick a random spot in dataset
|
||||
i = np.random.randint(0, self.data_size - req_len)
|
||||
|
||||
if args.data_type == "binidx":
|
||||
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
elif args.data_type == "numpy":
|
||||
dix = data[i : i + req_len]
|
||||
else:
|
||||
dix = [self.stoi[s] for s in data[i : i + req_len]]
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
if data == self.data_pile:
|
||||
z = [1] * ctx_len
|
||||
else:
|
||||
z = [0] * ctx_len
|
||||
z_sum = 0
|
||||
isGood = False
|
||||
for i in range(3, ctx_len):
|
||||
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
|
||||
isGood = True
|
||||
if dix[i] == 0:
|
||||
isGood = False
|
||||
if isGood:
|
||||
z[i] = 1
|
||||
z_sum += 1
|
||||
if z_sum == 0:
|
||||
z = [1] * ctx_len
|
||||
i = np.random.randint(0, self.data_pile_size - req_len)
|
||||
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
|
||||
z = torch.tensor(z, dtype=torch.bfloat16)
|
||||
|
||||
x = torch.tensor(dix[:-1], dtype=torch.long)
|
||||
y = torch.tensor(dix[1:], dtype=torch.long)
|
||||
|
||||
# if ii_orig < 50:
|
||||
# # if rank == 1:
|
||||
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
|
||||
# else:
|
||||
# exit(0)
|
||||
|
||||
if args.my_qa_mask == 1:
|
||||
return x, y, z
|
||||
|
||||
return x, y
|
678
finetune/lora/src/model.py
vendored
Normal file
678
finetune/lora/src/model.py
vendored
Normal file
@ -0,0 +1,678 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
import functools
|
||||
import os, math, gc, importlib
|
||||
import torch
|
||||
# torch._C._jit_set_profiling_executor(True)
|
||||
# torch._C._jit_set_profiling_mode(True)
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from torch.nn import functional as F
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy
|
||||
if importlib.util.find_spec('deepspeed'):
|
||||
import deepspeed
|
||||
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
|
||||
|
||||
# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
|
||||
|
||||
LORA_CONFIG = {
|
||||
"r": 0,
|
||||
"alpha": 0,
|
||||
"dropout": 0,
|
||||
"parts": {"att", "ln", "time"},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"])
|
||||
except:
|
||||
os.environ["RWKV_MY_TESTING"] = ''
|
||||
|
||||
def __nop(ob):
|
||||
return ob
|
||||
|
||||
|
||||
MyModule = nn.Module
|
||||
MyFunction = __nop
|
||||
if os.environ["RWKV_JIT_ON"] == "1":
|
||||
MyModule = torch.jit.ScriptModule
|
||||
MyFunction = torch.jit.script_method
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# CUDA Kernel
|
||||
########################################################################################################
|
||||
|
||||
T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM!
|
||||
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
ctx.B = B
|
||||
ctx.T = T
|
||||
ctx.C = C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w = -torch.exp(w.float().contiguous())
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
return y
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
T = ctx.T
|
||||
C = ctx.C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
return (None, None, None, gw, gu, gk, gv)
|
||||
else:
|
||||
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
|
||||
class WKV(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, B, T, C, w, u, k, v):
|
||||
ctx.B = B
|
||||
ctx.T = T
|
||||
ctx.C = C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
w = -torch.exp(w.contiguous())
|
||||
u = u.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
else:
|
||||
w = -torch.exp(w.float().contiguous())
|
||||
u = u.float().contiguous()
|
||||
k = k.float().contiguous()
|
||||
v = v.float().contiguous()
|
||||
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
|
||||
wkv_cuda.forward(B, T, C, w, u, k, v, y)
|
||||
ctx.save_for_backward(w, u, k, v, y)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
return y
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
return y.half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return y.bfloat16()
|
||||
@staticmethod
|
||||
def backward(ctx, gy):
|
||||
B = ctx.B
|
||||
T = ctx.T
|
||||
C = ctx.C
|
||||
assert T <= T_MAX
|
||||
assert B * C % min(C, 32) == 0
|
||||
w, u, k, v, y = ctx.saved_tensors
|
||||
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
|
||||
else:
|
||||
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
|
||||
gw = torch.sum(gw, dim=0)
|
||||
gu = torch.sum(gu, dim=0)
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
return (None, None, None, gw, gu, gk, gv)
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
|
||||
|
||||
|
||||
def RUN_CUDA(B, T, C, w, u, k, v):
|
||||
return WKV.apply(B, T, C, w, u, k, v)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# LoRA
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class LoraLinear(nn.Module):
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool):
|
||||
super().__init__()
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
|
||||
assert bias == False, "Biased LoraLinear not supported"
|
||||
|
||||
r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[
|
||||
"alpha"], LORA_CONFIG["dropout"]
|
||||
self.lora_A = nn.Parameter(torch.empty(r, in_features))
|
||||
self.lora_B = nn.Parameter(torch.empty(out_features, r))
|
||||
self.lora_dropout = nn.Dropout(dropout)
|
||||
self.scaling = alpha / r
|
||||
|
||||
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def forward(self, x):
|
||||
return (
|
||||
F.linear(x, self.weight) + self.scaling *
|
||||
F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B))
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_att(*args, **kwargs):
|
||||
if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
@functools.wraps(LoraLinear)
|
||||
def make_linear_ffn(*args, **kwargs):
|
||||
if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
|
||||
return LoraLinear(*args, **kwargs)
|
||||
else:
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
########################################################################################################
|
||||
# RWKV: RWKV Time-mix + RWKV Channel-mix
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class RWKV_TimeMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.ctx_len = args.ctx_len
|
||||
self.n_embd = args.n_embd
|
||||
|
||||
with torch.no_grad(): # fancy init
|
||||
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
|
||||
# fancy time_decay
|
||||
decay_speed = torch.ones(args.dim_att)
|
||||
for h in range(args.dim_att):
|
||||
decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
||||
self.time_decay = nn.Parameter(decay_speed)
|
||||
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
|
||||
|
||||
# fancy time_first
|
||||
zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5
|
||||
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)
|
||||
|
||||
# fancy time_mix
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
||||
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
|
||||
|
||||
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
d_qkv = args.n_embd // 16
|
||||
self.qq = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.kk = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
|
||||
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
|
||||
with torch.no_grad():
|
||||
self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
|
||||
|
||||
if 'a' not in os.environ["RWKV_MY_TESTING"]:
|
||||
@MyFunction
|
||||
def jit_func(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
sr = torch.sigmoid(r)
|
||||
return sr, k, v
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v = self.jit_func(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
return self.output(rwkv)
|
||||
|
||||
if 'a' in os.environ["RWKV_MY_TESTING"]:
|
||||
@MyFunction
|
||||
def QKV(self, q, k, v):
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.att_mask == 0, float('-inf'))
|
||||
att = F.softmax(att, dim = -1)
|
||||
x = att @ v
|
||||
return x
|
||||
|
||||
@MyFunction
|
||||
def jit_funcQKV(self, x):
|
||||
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
|
||||
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
|
||||
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)
|
||||
k = self.key(xk)
|
||||
v = self.value(xv)
|
||||
r = self.receptance(xr)
|
||||
sr = torch.sigmoid(r)
|
||||
qq = self.qq(xqq)
|
||||
kk = self.kk(xkk)
|
||||
vv = self.vv(xvv)
|
||||
return sr, k, v, qq, kk, vv
|
||||
|
||||
def forward(self, x):
|
||||
B, T, C = x.size() # x = (Batch,Time,Channel)
|
||||
sr, k, v, qq, kk, vv = self.jit_funcQKV(x)
|
||||
rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v)
|
||||
rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv))
|
||||
return rwkv
|
||||
|
||||
########################################################################################################
|
||||
|
||||
class RWKV_ChannelMix(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad(): # fancy init of time_mix
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
|
||||
ddd = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
ddd[0, 0, i] = i / args.n_embd
|
||||
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
|
||||
|
||||
self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
|
||||
self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
k = self.key(xk)
|
||||
k = torch.square(torch.relu(k))
|
||||
kv = self.value(k)
|
||||
return torch.sigmoid(self.receptance(xr)) * kv
|
||||
|
||||
class MishGLU(MyModule):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
|
||||
|
||||
with torch.no_grad():
|
||||
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)
|
||||
|
||||
x = torch.ones(1, 1, args.n_embd)
|
||||
for i in range(args.n_embd):
|
||||
x[0, 0, i] = i / args.n_embd
|
||||
|
||||
self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
|
||||
self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
|
||||
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
|
||||
|
||||
@MyFunction
|
||||
def forward(self, x):
|
||||
xx = self.time_shift(x)
|
||||
xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
|
||||
xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
|
||||
a = self.aa(xa)
|
||||
b = self.bb(xb)
|
||||
return self.value(a * F.mish(b))
|
||||
|
||||
########################################################################################################
|
||||
# The RWKV Model with our blocks
|
||||
########################################################################################################
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, args, layer_id):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.ln1 = nn.LayerNorm(args.n_embd)
|
||||
self.ln2 = nn.LayerNorm(args.n_embd)
|
||||
|
||||
if self.layer_id == 0:
|
||||
self.ln0 = nn.LayerNorm(args.n_embd)
|
||||
if args.my_pos_emb > 0:
|
||||
self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd)))
|
||||
self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd)))
|
||||
|
||||
if self.layer_id == 0 and self.args.pre_ffn > 0:
|
||||
self.ffnPre = RWKV_ChannelMix(args, 0)
|
||||
else:
|
||||
self.att = RWKV_TimeMix(args, layer_id)
|
||||
|
||||
if 'g' in os.environ["RWKV_MY_TESTING"]:
|
||||
self.ffn = MishGLU(args, layer_id)
|
||||
else:
|
||||
self.ffn = RWKV_ChannelMix(args, layer_id)
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
self.tiny_ln = nn.LayerNorm(args.n_embd)
|
||||
self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
|
||||
self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
|
||||
self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
|
||||
def forward(self, x, x_emb=None):
|
||||
args = self.args
|
||||
B, T, C = x.size()
|
||||
if self.layer_id == 0:
|
||||
x = self.ln0(x)
|
||||
if args.my_pos_emb > 0:
|
||||
pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:]
|
||||
x = x + pos_emb
|
||||
|
||||
if self.layer_id == 0 and args.pre_ffn > 0:
|
||||
x = x + self.ffnPre(self.ln1(x))
|
||||
else:
|
||||
x = x + self.att(self.ln1(x))
|
||||
x = x + self.ffn(self.ln2(x))
|
||||
|
||||
if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
|
||||
xx = self.tiny_ln(x)
|
||||
q = self.tiny_q(xx)[:, :T, :]
|
||||
k = self.tiny_k(xx)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
|
||||
c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
|
||||
x = x + c @ self.tiny_v(x_emb)
|
||||
return x
|
||||
|
||||
|
||||
class L2Wrap(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, loss, y):
|
||||
ctx.save_for_backward(y)
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
y = ctx.saved_tensors[0]
|
||||
# to encourage the logits to be close to 0
|
||||
factor = 1e-4 / (y.shape[0] * y.shape[1])
|
||||
maxx, ids = torch.max(y, -1, keepdim=True)
|
||||
gy = torch.zeros_like(y)
|
||||
gy.scatter_(-1, ids, maxx * factor)
|
||||
return (grad_output, gy)
|
||||
|
||||
|
||||
class RWKV(pl.LightningModule):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
if not hasattr(args, 'dim_att'):
|
||||
args.dim_att = args.n_embd
|
||||
if not hasattr(args, 'dim_ffn'):
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
if not hasattr(args, 'tiny_att_layer'):
|
||||
args.tiny_att_layer = -1
|
||||
if not hasattr(args, 'tiny_att_dim'):
|
||||
args.tiny_att_dim = -1
|
||||
|
||||
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
|
||||
|
||||
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
|
||||
|
||||
self.ln_out = nn.LayerNorm(args.n_embd)
|
||||
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
|
||||
|
||||
if args.head_qk > 0:
|
||||
self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
|
||||
self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
|
||||
|
||||
def configure_optimizers(self):
|
||||
args = self.args
|
||||
if args.layerwise_lr > 0:
|
||||
lr_1x = set()
|
||||
lr_2x = set()
|
||||
lr_3x = set()
|
||||
for n, p in self.named_parameters():
|
||||
if "time_mix" in n:
|
||||
if args.my_pile_stage == 2:
|
||||
lr_2x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
elif "time_decay" in n:
|
||||
if args.my_pile_stage == 2:
|
||||
lr_3x.add(n)
|
||||
else:
|
||||
lr_2x.add(n)
|
||||
elif "time_first" in n:
|
||||
lr_3x.add(n)
|
||||
else:
|
||||
lr_1x.add(n)
|
||||
lr_1x = sorted(list(lr_1x))
|
||||
lr_2x = sorted(list(lr_2x))
|
||||
lr_3x = sorted(list(lr_3x))
|
||||
# print('1x', lr_1x)
|
||||
# print('2x', lr_2x)
|
||||
# print('3x', lr_3x)
|
||||
param_dict = {n: p for n, p in self.named_parameters()}
|
||||
if args.my_pile_stage == 2:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0},
|
||||
{"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0},
|
||||
{"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0},
|
||||
]
|
||||
else:
|
||||
optim_groups = [
|
||||
{"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
for g in optim_groups:
|
||||
g["params"] = [p for p in g["params"] if p.requires_grad]
|
||||
optim_groups = [g for g in optim_groups if len(g["params"]) > 0]
|
||||
|
||||
if self.deepspeed_offload:
|
||||
return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False)
|
||||
return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
|
||||
# return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)
|
||||
|
||||
@property
|
||||
def deepspeed_offload(self) -> bool:
|
||||
strategy = self.trainer.strategy
|
||||
if isinstance(strategy, DeepSpeedStrategy):
|
||||
cfg = strategy.config["zero_optimization"]
|
||||
return cfg.get("offload_optimizer") or cfg.get("offload_param")
|
||||
return False
|
||||
|
||||
def forward(self, idx):
|
||||
args = self.args
|
||||
B, T = idx.size()
|
||||
assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
|
||||
|
||||
x = self.emb(idx)
|
||||
x_emb = x
|
||||
|
||||
if args.tiny_att_dim > 0:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
|
||||
else:
|
||||
x = block(x, x_emb)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
if args.grad_cp == 1:
|
||||
if args.lora:
|
||||
x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
|
||||
else:
|
||||
x = deepspeed.checkpointing.checkpoint(block, x)
|
||||
else:
|
||||
x = block(x)
|
||||
|
||||
x = self.ln_out(x)
|
||||
|
||||
if args.head_qk > 0:
|
||||
q = self.head_q(x)[:, :T, :]
|
||||
k = self.head_k(x)[:, :T, :]
|
||||
c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
|
||||
c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
|
||||
|
||||
if "32" in os.environ["RWKV_FLOAT_MODE"]:
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size)
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
|
||||
|
||||
x = self.head(x) + c
|
||||
else:
|
||||
x = self.head(x)
|
||||
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
args = self.args
|
||||
if args.my_qa_mask != 1:
|
||||
idx, targets = batch
|
||||
logits = self(idx)
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
else:
|
||||
idx, targets, mask = batch
|
||||
mask = mask.view(-1)
|
||||
sum_mask = torch.sum(mask).item()
|
||||
# if sum_mask == 0:
|
||||
# return torch.tensor([0.0], requires_grad=True)
|
||||
|
||||
logits = self(idx)
|
||||
if sum_mask == mask.shape[0]:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
|
||||
# print('rank', self.global_rank, 'loss', loss.item())
|
||||
else:
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
|
||||
# loss_raw = loss
|
||||
loss = torch.sum(loss * mask) / sum_mask
|
||||
|
||||
# torch.set_printoptions(threshold=10000)
|
||||
# if True: #self.global_rank == 1:
|
||||
# tmp = ''
|
||||
# sss = 0
|
||||
# ccc = 0
|
||||
# for i in range(mask.shape[0]):
|
||||
# if mask[i] > 0:
|
||||
# tmp += str(idx.view(-1)[i].item()) + ','
|
||||
# sss += loss_raw.view(-1)[i].float().item()
|
||||
# ccc += 1
|
||||
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)
|
||||
|
||||
return L2Wrap.apply(loss, logits)
|
||||
|
||||
def training_step_end(self, batch_parts):
|
||||
all = self.all_gather(batch_parts)
|
||||
if self.trainer.is_global_zero:
|
||||
self.trainer.my_loss_all = all
|
||||
|
||||
def generate_init_weight(self):
|
||||
print(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# Init model weight (slow for large models)...
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
m = {}
|
||||
for n in self.state_dict():
|
||||
p = self.state_dict()[n]
|
||||
shape = p.shape
|
||||
|
||||
gain = 1.0
|
||||
scale = 1.0
|
||||
if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n:
|
||||
m[n] = p
|
||||
else:
|
||||
if n == "emb.weight":
|
||||
scale = -1 * self.args.lr_init
|
||||
else:
|
||||
if shape[0] > shape[1]:
|
||||
gain = math.sqrt(shape[0] / shape[1])
|
||||
for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']:
|
||||
if kk in n:
|
||||
scale = 0
|
||||
if n == "head.weight":
|
||||
scale = 0.5
|
||||
if "head_k." in n:
|
||||
scale = 0.1
|
||||
if "head_q." in n:
|
||||
scale = 0
|
||||
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}")
|
||||
|
||||
if self.args.accelerator.upper() == "GPU":
|
||||
m[n] = torch.empty((shape[0], shape[1]), device="cuda")
|
||||
else:
|
||||
m[n] = torch.empty((shape[0], shape[1]))
|
||||
|
||||
if scale == 0:
|
||||
nn.init.zeros_(m[n])
|
||||
elif scale < 0:
|
||||
nn.init.uniform_(m[n], a=scale, b=-scale)
|
||||
else:
|
||||
nn.init.orthogonal_(m[n], gain=gain * scale)
|
||||
|
||||
m[n] = m[n].cpu()
|
||||
if os.environ["RWKV_FLOAT_MODE"] == "fp16":
|
||||
m[n] = m[n].half()
|
||||
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
|
||||
m[n] = m[n].bfloat16()
|
||||
|
||||
# if n == "emb.weight":
|
||||
# print(m[n])
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return m
|
203
finetune/lora/src/trainer.py
vendored
Normal file
203
finetune/lora/src/trainer.py
vendored
Normal file
@ -0,0 +1,203 @@
|
||||
import os, math, time, datetime, subprocess
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
from .model import LORA_CONFIG
|
||||
|
||||
def my_save(dd, ff):
|
||||
if '14b-run1' not in ff:
|
||||
torch.save(dd, ff)
|
||||
else:
|
||||
fn = ff.split('/')[-1]
|
||||
fff = '/dev/shm/' + fn
|
||||
torch.save(dd, fff)
|
||||
subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True)
|
||||
|
||||
class train_callback(pl.Callback):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
||||
args = self.args
|
||||
# if args.cuda_cleanup > 0:
|
||||
# torch.cuda.empty_cache()
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
|
||||
# LR schedule
|
||||
w_step = args.warmup_steps
|
||||
if args.lr_final == args.lr_init or args.epoch_count == 0:
|
||||
lr = args.lr_init
|
||||
else:
|
||||
decay_step = real_step - args.my_pile_edecay * args.epoch_steps
|
||||
decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
|
||||
progress = (decay_step - w_step + 1) / (decay_total - w_step)
|
||||
progress = min(1, max(0, progress))
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0: # linear decay
|
||||
lr = args.lr_init + (args.lr_final - args.lr_init) * progress
|
||||
else: # exp decay
|
||||
lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
|
||||
|
||||
if trainer.global_step < w_step:
|
||||
lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
|
||||
# if trainer.is_global_zero:
|
||||
# print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
|
||||
|
||||
for param_group in trainer.optimizers[0].param_groups:
|
||||
if args.layerwise_lr > 0:
|
||||
param_group["lr"] = lr * param_group["my_lr_scale"]
|
||||
# print(param_group["lr"], param_group["my_lr_scale"])
|
||||
else:
|
||||
param_group["lr"] = lr
|
||||
|
||||
trainer.my_lr = lr
|
||||
# rank_zero_info(f"{real_step} {lr}")
|
||||
|
||||
if trainer.global_step == 0:
|
||||
if trainer.is_global_zero: # logging
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
|
||||
trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
|
||||
try:
|
||||
print(f"\n{trainer.strategy.config}\n")
|
||||
trainer.my_log.write(f"{trainer.strategy.config}\n")
|
||||
except:
|
||||
pass
|
||||
trainer.my_log.flush()
|
||||
if len(args.wandb) > 0:
|
||||
print("Login to wandb...")
|
||||
import wandb
|
||||
wandb.init(
|
||||
project=args.wandb,
|
||||
name=args.run_name + " " + args.my_timestamp,
|
||||
config=args,
|
||||
save_code=False,
|
||||
)
|
||||
trainer.my_wandb = wandb
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
args = self.args
|
||||
if trainer.is_global_zero: # logging
|
||||
t_now = time.time_ns()
|
||||
token_per_step = args.ctx_len * args.real_bsz
|
||||
real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
|
||||
kt_s = 0
|
||||
try:
|
||||
t_cost = (t_now - trainer.my_time_ns) / 1e9
|
||||
kt_s = token_per_step / t_cost / 1000
|
||||
self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
|
||||
self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
|
||||
except:
|
||||
pass
|
||||
trainer.my_time_ns = t_now
|
||||
trainer.my_loss = trainer.my_loss_all.float().mean().item()
|
||||
trainer.my_loss_sum += trainer.my_loss
|
||||
trainer.my_loss_count += 1
|
||||
trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
|
||||
self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
|
||||
self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
|
||||
# self.log("s", real_step, prog_bar=True, on_step=True)
|
||||
|
||||
if len(args.wandb) > 0:
|
||||
lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
|
||||
if kt_s > 0:
|
||||
lll["kt/s"] = kt_s
|
||||
trainer.my_wandb.log(lll, step=int(real_step))
|
||||
if args.magic_prime > 0:
|
||||
expand_factor = 2 if args.my_qa_mask > 0 else 1
|
||||
if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
my_save(
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-final.pth",
|
||||
)
|
||||
|
||||
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
args = self.args
|
||||
dataset = trainer.train_dataloader.dataset.datasets
|
||||
assert "MyDataset" in str(dataset)
|
||||
dataset.global_rank = trainer.global_rank
|
||||
dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
|
||||
dataset.world_size = trainer.world_size
|
||||
# print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
args = self.args
|
||||
if trainer.is_global_zero: # logging & save state_dict
|
||||
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
|
||||
if args.data_type == 'wds_img':
|
||||
raw_dict = pl_module.state_dict()
|
||||
to_save_dict = {}
|
||||
for k in raw_dict:
|
||||
if k.startswith('encoder.') or k.startswith('decoder.'):
|
||||
to_save_dict[k] = raw_dict[k]
|
||||
else:
|
||||
to_save_dict = pl_module.state_dict()
|
||||
|
||||
if args.lora:
|
||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
||||
lora_dict = {}
|
||||
for name, state in to_save_dict.items():
|
||||
if ('.lora_' in name
|
||||
or (enable_time_finetune and '.time_' in name)
|
||||
or (enable_ln_finetune and '.ln' in name)):
|
||||
lora_dict[name] = state
|
||||
to_save_dict = lora_dict
|
||||
|
||||
try:
|
||||
my_save(
|
||||
to_save_dict,
|
||||
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
|
||||
)
|
||||
except Exception as e:
|
||||
print('Error\n\n', e, '\n\n')
|
||||
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
|
||||
trainer.my_log.flush()
|
||||
|
||||
trainer.my_loss_sum = 0
|
||||
trainer.my_loss_count = 0
|
||||
|
||||
|
||||
@rank_zero_only
|
||||
def generate_init_weight(model, init_weight_name):
|
||||
mm = model.generate_init_weight()
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
if len(model.args.load_model) > 0:
|
||||
print(f"Combine weights from {model.args.load_model}...")
|
||||
load_dict = torch.load(model.args.load_model, map_location="cpu")
|
||||
for k in load_dict:
|
||||
assert k in mm
|
||||
src = load_dict[k]
|
||||
try:
|
||||
mm[k] = src.reshape(mm[k].shape)
|
||||
except:
|
||||
tmp = mm[k].squeeze().clone()
|
||||
print(k, src.shape, '-->', mm[k].shape)
|
||||
ss = src.shape[0]
|
||||
dd = tmp.shape[0]
|
||||
for i in range(dd):
|
||||
pos = i / dd * ss
|
||||
if pos >= ss - 1:
|
||||
tmp[i] = src[ss-1]
|
||||
else:
|
||||
p0 = int(math.floor(pos))
|
||||
ii = pos - p0
|
||||
tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
|
||||
mm[k] = tmp.reshape(mm[k].shape)
|
||||
sss = src.squeeze().float().cpu().numpy()
|
||||
print(sss[:10], '...', sss[-10:])
|
||||
mmm = mm[k].squeeze().float().cpu().numpy()
|
||||
print(mmm[:10], '...', mmm[-10:])
|
||||
|
||||
print(f"Save to {init_weight_name}...")
|
||||
torch.save(mm, init_weight_name)
|
||||
|
||||
if model.args.my_pile_stage == 1:
|
||||
print("Done. Now go for stage 2.")
|
||||
exit(0)
|
130
finetune/lora/src/utils.py
vendored
Normal file
130
finetune/lora/src/utils.py
vendored
Normal file
@ -0,0 +1,130 @@
|
||||
import json, time, random, os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
time_slot = {}
|
||||
time_ref = time.time_ns()
|
||||
|
||||
def record_time(name):
|
||||
if name not in time_slot:
|
||||
time_slot[name] = 1e20
|
||||
tt = (time.time_ns() - time_ref) / 1e9
|
||||
if tt < time_slot[name]:
|
||||
time_slot[name] = tt
|
||||
|
||||
class TOKENIZER():
|
||||
def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
|
||||
if 'list' in str(type(WORD_NAME)):
|
||||
self.charMode = False
|
||||
if WORD_NAME[0] == WORD_NAME[1]:
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
|
||||
else:
|
||||
from transformers import GPT2TokenizerFast
|
||||
self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
|
||||
self.vocab_size = len(self.tokenizer)
|
||||
else:
|
||||
self.charMode = True
|
||||
with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
|
||||
self.word_table = json.load(result_file)
|
||||
|
||||
self.vocab_size = len(self.word_table)
|
||||
|
||||
self.stoi = {v: int(k) for k, v in self.word_table.items()}
|
||||
self.itos = {int(k): v for k, v in self.word_table.items()}
|
||||
|
||||
self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]
|
||||
|
||||
def refine_context(self, context):
|
||||
context = context.strip().split('\n')
|
||||
for c in range(len(context)):
|
||||
context[c] = context[c].strip().strip('\u3000').strip('\r')
|
||||
context = list(filter(lambda c: c != '', context))
|
||||
context = '\n' + ('\n'.join(context)).strip()
|
||||
if context == '':
|
||||
context = '\n'
|
||||
return context
|
||||
|
||||
def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
|
||||
# out[self.UNKNOWN_CHAR] = -float('Inf')
|
||||
lastChar = int(x[-1])
|
||||
|
||||
probs = F.softmax(out, dim=-1)
|
||||
|
||||
if self.charMode:
|
||||
if self.itos[lastChar] == '\n':
|
||||
top_p = top_p_newline
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
else:
|
||||
top_p = top_p_usual
|
||||
|
||||
if os.environ["RWKV_RUN_DEVICE"] == "cpu":
|
||||
probs = probs.numpy()
|
||||
sorted_probs = np.sort(probs)[::-1]
|
||||
cumulative_probs = np.cumsum(sorted_probs)
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
probs = probs / np.sum(probs)
|
||||
out = np.random.choice(a=len(probs), p=probs)
|
||||
return out
|
||||
else:
|
||||
sorted_probs = torch.sort(probs, descending=True)[0]
|
||||
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||
probs[probs < cutoff] = 0
|
||||
if temperature != 1.0:
|
||||
probs = probs.pow(1.0 / temperature)
|
||||
out = torch.multinomial(probs, num_samples=1)[0]
|
||||
return out
|
||||
|
||||
def MaybeIsPrime(number):
|
||||
if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def FermatPrimalityTest(number):
|
||||
if number > 1:
|
||||
for time in range(3):
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if pow(randomNumber, number - 1, number) != 1:
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def MillerRabinPrimalityTest(number):
|
||||
if number == 2:
|
||||
return True
|
||||
elif number == 1 or number % 2 == 0:
|
||||
return False
|
||||
oddPartOfNumber = number - 1
|
||||
timesTwoDividNumber = 0
|
||||
while oddPartOfNumber % 2 == 0:
|
||||
oddPartOfNumber = oddPartOfNumber // 2
|
||||
timesTwoDividNumber = timesTwoDividNumber + 1
|
||||
|
||||
for time in range(3):
|
||||
while True:
|
||||
randomNumber = random.randint(2, number) - 1
|
||||
if randomNumber != 0 and randomNumber != 1:
|
||||
break
|
||||
|
||||
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
|
||||
|
||||
if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
|
||||
iterationNumber = 1
|
||||
|
||||
while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
|
||||
randomNumberWithPower = pow(randomNumberWithPower, 2, number)
|
||||
iterationNumber = iterationNumber + 1
|
||||
if randomNumberWithPower != (number - 1):
|
||||
return False
|
||||
|
||||
return True
|
388
finetune/lora/train.py
vendored
Normal file
388
finetune/lora/train.py
vendored
Normal file
@ -0,0 +1,388 @@
|
||||
########################################################################################################
|
||||
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||
########################################################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
|
||||
|
||||
rank_zero_info("########## work in progress ##########")
|
||||
|
||||
########################################################################################################
|
||||
#
|
||||
# example: train a simple L12-D768 RWKV on dummy data
|
||||
#
|
||||
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
||||
# --data_file "" --data_type "dummy" --vocab_size 0 \
|
||||
# --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
|
||||
# --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
||||
|
||||
# example: train a simple L6-D512 RWKV from scratch on enwik8
|
||||
#
|
||||
# python train.py --load_model "" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \
|
||||
# --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \
|
||||
# --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
|
||||
|
||||
# example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M
|
||||
#
|
||||
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
||||
# --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \
|
||||
# --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0
|
||||
|
||||
# example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow
|
||||
#
|
||||
# python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \
|
||||
# --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \
|
||||
# --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \
|
||||
# --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
|
||||
# --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
|
||||
# --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument("--load_model", default="", type=str) # full path, with .pth
|
||||
parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb
|
||||
parser.add_argument("--proj_dir", default="out", type=str)
|
||||
parser.add_argument("--random_seed", default="-1", type=int)
|
||||
|
||||
parser.add_argument("--data_file", default="", type=str)
|
||||
parser.add_argument("--data_type", default="utf-8", type=str)
|
||||
parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data)
|
||||
|
||||
parser.add_argument("--ctx_len", default=1024, type=int)
|
||||
parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps
|
||||
parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final
|
||||
parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x
|
||||
parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs"
|
||||
|
||||
parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU)
|
||||
parser.add_argument("--n_layer", default=6, type=int)
|
||||
parser.add_argument("--n_embd", default=512, type=int)
|
||||
parser.add_argument("--dim_att", default=0, type=int)
|
||||
parser.add_argument("--dim_ffn", default=0, type=int)
|
||||
parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better)
|
||||
parser.add_argument("--head_qk", default=0, type=int) # my headQK trick
|
||||
parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim
|
||||
parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer
|
||||
|
||||
parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
|
||||
parser.add_argument("--lr_final", default=1e-5, type=float)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model
|
||||
parser.add_argument("--beta1", default=0.9, type=float)
|
||||
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
|
||||
parser.add_argument("--adam_eps", default=1e-8, type=float)
|
||||
|
||||
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
|
||||
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
|
||||
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
|
||||
parser.add_argument("--my_pile_edecay", default=0, type=int)
|
||||
parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s)
|
||||
parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough
|
||||
# parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful)
|
||||
|
||||
parser.add_argument("--my_img_version", default=0, type=str)
|
||||
parser.add_argument("--my_img_size", default=0, type=int)
|
||||
parser.add_argument("--my_img_bit", default=0, type=int)
|
||||
parser.add_argument("--my_img_clip", default='x', type=str)
|
||||
parser.add_argument("--my_img_clip_scale", default=1, type=float)
|
||||
parser.add_argument("--my_img_l1_scale", default=0, type=float)
|
||||
parser.add_argument("--my_img_encoder", default='x', type=str)
|
||||
# parser.add_argument("--my_img_noise_scale", default=0, type=float)
|
||||
parser.add_argument("--my_sample_len", default=0, type=int)
|
||||
parser.add_argument("--my_ffn_shift", default=1, type=int)
|
||||
parser.add_argument("--my_att_shift", default=1, type=int)
|
||||
parser.add_argument("--my_pos_emb", default=0, type=int)
|
||||
parser.add_argument("--load_partial", default=0, type=int)
|
||||
parser.add_argument("--magic_prime", default=0, type=int)
|
||||
parser.add_argument("--my_qa_mask", default=0, type=int)
|
||||
parser.add_argument("--my_testing", default='', type=str)
|
||||
|
||||
parser.add_argument("--lora", action="store_true")
|
||||
parser.add_argument("--lora_load", default="", type=str)
|
||||
parser.add_argument("--lora_r", default=8, type=int)
|
||||
parser.add_argument("--lora_alpha", default=32, type=float)
|
||||
parser.add_argument("--lora_dropout", default=0.01, type=float)
|
||||
parser.add_argument("--lora_parts", default="att,ln,time", type=str)
|
||||
|
||||
parser = Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
########################################################################################################
|
||||
|
||||
import os, warnings, math, datetime, sys, time, importlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
if "deepspeed" in args.strategy:
|
||||
import deepspeed
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
if args.random_seed >= 0:
|
||||
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
|
||||
seed_everything(args.random_seed)
|
||||
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
|
||||
warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
|
||||
# os.environ["WDS_SHOW_SEED"] = "1"
|
||||
|
||||
args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
args.enable_checkpointing = False
|
||||
args.replace_sampler_ddp = False
|
||||
args.logger = False
|
||||
args.gradient_clip_val = 1.0
|
||||
args.num_sanity_val_steps = 0
|
||||
args.check_val_every_n_epoch = int(1e20)
|
||||
args.log_every_n_steps = int(1e20)
|
||||
args.max_epochs = -1 # continue forever
|
||||
args.betas = (args.beta1, args.beta2)
|
||||
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
|
||||
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
|
||||
os.environ["RWKV_MY_TESTING"] = args.my_testing
|
||||
if args.dim_att <= 0:
|
||||
args.dim_att = args.n_embd
|
||||
if args.dim_ffn <= 0:
|
||||
args.dim_ffn = args.n_embd * 4
|
||||
|
||||
if args.data_type == "wds_img":
|
||||
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
|
||||
args.proj_dir = f"{args.proj_dir}-{args.run_name}"
|
||||
else:
|
||||
args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"
|
||||
if not os.path.exists(args.proj_dir):
|
||||
os.makedirs(args.proj_dir)
|
||||
|
||||
if args.my_pile_stage > 0:
|
||||
magic_prime_bak = args.magic_prime
|
||||
if args.ctx_len == 1024:
|
||||
args.magic_prime = 324331313
|
||||
args.epoch_count = 8043
|
||||
elif args.ctx_len == 2048:
|
||||
args.magic_prime = 162165671
|
||||
args.epoch_count = 4021
|
||||
elif args.ctx_len == 4096:
|
||||
args.magic_prime = 81082817
|
||||
args.epoch_count = 2010
|
||||
if args.my_pile_shift < 0:
|
||||
if args.ctx_len == 1024:
|
||||
args.my_pile_shift = 0
|
||||
elif args.ctx_len == 2048:
|
||||
args.my_pile_shift = 512
|
||||
elif args.ctx_len == 4096:
|
||||
args.my_pile_shift = 768
|
||||
|
||||
if magic_prime_bak > 0:
|
||||
args.magic_prime = magic_prime_bak
|
||||
|
||||
args.epoch_steps = 40320 // args.real_bsz
|
||||
assert args.epoch_steps * args.real_bsz == 40320
|
||||
if args.my_pile_stage == 2:
|
||||
assert args.lr_final == args.lr_init
|
||||
if args.my_pile_stage >= 2: # find latest saved model
|
||||
list_p = []
|
||||
for p in os.listdir(args.proj_dir):
|
||||
if p.startswith("rwkv") and p.endswith(".pth"):
|
||||
p = ((p.split("-"))[1].split("."))[0]
|
||||
if p == "init":
|
||||
p = -1
|
||||
else:
|
||||
p = int(p)
|
||||
list_p += [p]
|
||||
list_p.sort()
|
||||
max_p = list_p[-1]
|
||||
if len(list_p) > 1:
|
||||
args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
if args.my_pile_stage == 2:
|
||||
args.warmup_steps = 10
|
||||
else:
|
||||
args.warmup_steps = 30
|
||||
args.epoch_begin = max_p + 1
|
||||
|
||||
samples_per_epoch = args.epoch_steps * args.real_bsz
|
||||
tokens_per_epoch = samples_per_epoch * args.ctx_len
|
||||
rank_zero_info(
|
||||
f"""
|
||||
############################################################################
|
||||
#
|
||||
# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''}
|
||||
#
|
||||
# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir}
|
||||
#
|
||||
# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch
|
||||
#
|
||||
# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens
|
||||
#
|
||||
# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len
|
||||
# LoRA = {f'enabled, {args.lora_r} r, {args.lora_alpha} alpha, {args.lora_dropout} dropout, on {args.lora_parts}' if args.lora else 'disabled'}
|
||||
#
|
||||
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
|
||||
#
|
||||
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
|
||||
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
|
||||
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
|
||||
#
|
||||
############################################################################
|
||||
"""
|
||||
)
|
||||
rank_zero_info(str(vars(args)) + "\n")
|
||||
|
||||
assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]
|
||||
|
||||
if args.lr_final == 0 or args.lr_init == 0:
|
||||
rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n")
|
||||
|
||||
assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
|
||||
os.environ["RWKV_FLOAT_MODE"] = args.precision
|
||||
if args.precision == "fp32":
|
||||
for i in range(10):
|
||||
rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n")
|
||||
if args.precision == "fp16":
|
||||
rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n")
|
||||
|
||||
os.environ["RWKV_JIT_ON"] = "1"
|
||||
if "deepspeed_stage_3" in args.strategy:
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
if args.lora and args.grad_cp == 1:
|
||||
print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it')
|
||||
os.environ["RWKV_JIT_ON"] = "0"
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
if args.precision == "fp32":
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
else:
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if "32" in args.precision:
|
||||
args.precision = 32
|
||||
elif args.precision == "fp16":
|
||||
args.precision = 16
|
||||
else:
|
||||
args.precision = "bf16"
|
||||
|
||||
########################################################################################################
|
||||
|
||||
from src.trainer import train_callback, generate_init_weight
|
||||
from src.dataset import MyDataset
|
||||
|
||||
train_data = MyDataset(args)
|
||||
args.vocab_size = train_data.vocab_size
|
||||
|
||||
if args.data_type == 'wds_img':
|
||||
from src.model_img import RWKV_IMG
|
||||
assert args.lora, "LoRA not yet supported for RWKV_IMG"
|
||||
model = RWKV_IMG(args)
|
||||
else:
|
||||
from src.model import RWKV, LORA_CONFIG, LoraLinear
|
||||
if args.lora:
|
||||
assert args.lora_r > 0, "LoRA should have its `r` > 0"
|
||||
LORA_CONFIG["r"] = args.lora_r
|
||||
LORA_CONFIG["alpha"] = args.lora_alpha
|
||||
LORA_CONFIG["dropout"] = args.lora_dropout
|
||||
LORA_CONFIG["parts"] = set(str(args.lora_parts).split(','))
|
||||
enable_time_finetune = 'time' in LORA_CONFIG["parts"]
|
||||
enable_ln_finetune = 'ln' in LORA_CONFIG["parts"]
|
||||
model = RWKV(args)
|
||||
# only train lora parameters
|
||||
if args.lora:
|
||||
model.requires_grad_(False)
|
||||
for name, module in model.named_modules():
|
||||
# have to check param name since it may have been wrapped by torchscript
|
||||
if any(n.startswith("lora_") for n, _ in module.named_parameters()):
|
||||
print(f' LoRA training module {name}')
|
||||
for pname, param in module.named_parameters():
|
||||
param.requires_grad = 'lora_' in pname
|
||||
elif enable_ln_finetune and '.ln' in name:
|
||||
print(f' LoRA additionally training module {name}')
|
||||
for param in module.parameters():
|
||||
param.requires_grad = True
|
||||
elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()):
|
||||
for pname, param in module.named_parameters():
|
||||
if pname.startswith("time"):
|
||||
print(f' LoRA additionally training parameter {pname}')
|
||||
param.requires_grad = True
|
||||
|
||||
if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights?
|
||||
init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
|
||||
generate_init_weight(model, init_weight_name) # save initial weights
|
||||
args.load_model = init_weight_name
|
||||
|
||||
rank_zero_info(f"########## Loading {args.load_model}... ##########")
|
||||
try:
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
except:
|
||||
rank_zero_info(f"Bad checkpoint {args.load_model}")
|
||||
if args.my_pile_stage >= 2: # try again using another checkpoint
|
||||
max_p = args.my_pile_prev_p
|
||||
if max_p == -1:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
|
||||
else:
|
||||
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
|
||||
args.epoch_begin = max_p + 1
|
||||
rank_zero_info(f"Trying {args.load_model}")
|
||||
load_dict = torch.load(args.load_model, map_location="cpu")
|
||||
|
||||
if args.load_partial == 1:
|
||||
load_keys = load_dict.keys()
|
||||
for k in model.state_dict():
|
||||
if k not in load_keys:
|
||||
load_dict[k] = model.state_dict()[k]
|
||||
# If using LoRA, the LoRA keys might be missing in the original model
|
||||
model.load_state_dict(load_dict, strict=(not args.lora))
|
||||
if os.path.isfile(args.lora_load):
|
||||
model.load_state_dict(torch.load(args.lora_load, map_location="cpu"),
|
||||
strict=False)
|
||||
|
||||
trainer: Trainer = Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[train_callback(args)],
|
||||
)
|
||||
|
||||
if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8):
|
||||
if 'I_KNOW_WHAT_IM_DOING' in os.environ:
|
||||
if trainer.global_rank == 0:
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
else:
|
||||
if trainer.global_rank == 0:
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)')
|
||||
print(f' Unless you are sure this is what you want, adjust them accordingly')
|
||||
print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")')
|
||||
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
|
||||
exit(0)
|
||||
|
||||
if trainer.global_rank == 0:
|
||||
for n in model.state_dict():
|
||||
shape = model.state_dict()[n].shape
|
||||
shape = [i for i in shape if i != 1]
|
||||
if len(shape) > 1:
|
||||
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
|
||||
else:
|
||||
print(f"{str(shape[0]).ljust(5)} {n}")
|
||||
|
||||
if "deepspeed" in args.strategy:
|
||||
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
|
||||
|
||||
# must set shuffle=False, persistent_workers=False (because worker is in another thread)
|
||||
data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
|
||||
|
||||
trainer.fit(model, data_loader)
|
3
finetune/requirements.txt
Normal file
3
finetune/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
torch==1.13.1
|
||||
pytorch_lightning==1.9.5
|
||||
deepspeed
|
27
frontend/package-lock.json
generated
27
frontend/package-lock.json
generated
@ -12,6 +12,7 @@
|
||||
"@fluentui/react-icons": "^2.0.201",
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@primer/octicons-react": "^19.1.0",
|
||||
"chart.js": "^4.3.0",
|
||||
"classnames": "^2.3.2",
|
||||
"github-markdown-css": "^5.2.0",
|
||||
"i18next": "^22.4.15",
|
||||
@ -19,6 +20,7 @@
|
||||
"mobx-react-lite": "^3.4.3",
|
||||
"react": "^18.2.0",
|
||||
"react-beautiful-dnd": "^13.1.1",
|
||||
"react-chartjs-2": "^5.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-i18next": "^12.2.2",
|
||||
"react-markdown": "^8.0.7",
|
||||
@ -1903,6 +1905,11 @@
|
||||
"integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@kurkle/color": {
|
||||
"version": "0.3.2",
|
||||
"resolved": "https://registry.npmjs.org/@kurkle/color/-/color-0.3.2.tgz",
|
||||
"integrity": "sha512-fuscdXJ9G1qb7W8VdHi+IwRqij3lBkosAm4ydQtEmbY58OzHXqQhvlxqEkoz0yssNVn38bcpRWgA9PP+OGoisw=="
|
||||
},
|
||||
"node_modules/@microsoft/fetch-event-source": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz",
|
||||
@ -2258,6 +2265,17 @@
|
||||
"resolved": "https://registry.npmmirror.com/character-entities/-/character-entities-2.0.2.tgz",
|
||||
"integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ=="
|
||||
},
|
||||
"node_modules/chart.js": {
|
||||
"version": "4.3.0",
|
||||
"resolved": "https://registry.npmjs.org/chart.js/-/chart.js-4.3.0.tgz",
|
||||
"integrity": "sha512-ynG0E79xGfMaV2xAHdbhwiPLczxnNNnasrmPEXriXsPJGjmhOBYzFVEsB65w2qMDz+CaBJJuJD0inE/ab/h36g==",
|
||||
"dependencies": {
|
||||
"@kurkle/color": "^0.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"pnpm": ">=7"
|
||||
}
|
||||
},
|
||||
"node_modules/chokidar": {
|
||||
"version": "3.5.3",
|
||||
"resolved": "https://registry.npmmirror.com/chokidar/-/chokidar-3.5.3.tgz",
|
||||
@ -3884,6 +3902,15 @@
|
||||
"react-dom": "^16.8.5 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-chartjs-2": {
|
||||
"version": "5.2.0",
|
||||
"resolved": "https://registry.npmjs.org/react-chartjs-2/-/react-chartjs-2-5.2.0.tgz",
|
||||
"integrity": "sha512-98iN5aguJyVSxp5U3CblRLH67J8gkfyGNbiK3c+l1QI/G4irHMPQw44aEPmjVag+YKTyQ260NcF82GTQ3bdscA==",
|
||||
"peerDependencies": {
|
||||
"chart.js": "^4.1.1",
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-dom": {
|
||||
"version": "18.2.0",
|
||||
"resolved": "https://registry.npmmirror.com/react-dom/-/react-dom-18.2.0.tgz",
|
||||
|
@ -13,6 +13,7 @@
|
||||
"@fluentui/react-icons": "^2.0.201",
|
||||
"@microsoft/fetch-event-source": "^2.0.1",
|
||||
"@primer/octicons-react": "^19.1.0",
|
||||
"chart.js": "^4.3.0",
|
||||
"classnames": "^2.3.2",
|
||||
"github-markdown-css": "^5.2.0",
|
||||
"i18next": "^22.4.15",
|
||||
@ -20,6 +21,7 @@
|
||||
"mobx-react-lite": "^3.4.3",
|
||||
"react": "^18.2.0",
|
||||
"react-beautiful-dnd": "^13.1.1",
|
||||
"react-chartjs-2": "^5.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-i18next": "^12.2.2",
|
||||
"react-markdown": "^8.0.7",
|
||||
|
@ -189,5 +189,37 @@
|
||||
"user": "用户",
|
||||
"assistant": "AI",
|
||||
"system": "系统",
|
||||
"Regenerate": "重新生成"
|
||||
"Regenerate": "重新生成",
|
||||
"LoRA Finetune": "LoRA微调",
|
||||
"Command Stopped": "命令已终止",
|
||||
"Please convert data first.": "请先转换数据",
|
||||
"Ubuntu is not installed, do you want to install it?": "Ubuntu未安装,是否安装?",
|
||||
"Install Ubuntu": "安装Ubuntu",
|
||||
"Please install Ubuntu using Microsoft Store": "请用Microsoft Store安装Ubuntu",
|
||||
"WSL is not enabled, do you want to enable it?": "WSL未启用,是否启用?",
|
||||
"Enable WSL": "启用WSL",
|
||||
"After installation, please restart your computer to enable WSL": "安装完成后,请重启电脑以启用WSL",
|
||||
"Data Process": "数据处理",
|
||||
"Data Path": "数据路径",
|
||||
"Vocab Path": "词表路径",
|
||||
"Train Parameters": "训练参数",
|
||||
"Base Model": "基底模型",
|
||||
"LoRA Model": "LoRA模型",
|
||||
"Merge Model": "合并模型",
|
||||
"Devices": "显卡数量",
|
||||
"Gradient Checkpoint": "梯度检查点标志",
|
||||
"Context Length": "上下文长度",
|
||||
"Epoch Steps": "每轮训练步数",
|
||||
"Epoch Count": "训练轮次",
|
||||
"Epoch Begin": "起始轮次",
|
||||
"Epoch Save": "保存间隔轮次",
|
||||
"Learning Rate Init": "初始学习率",
|
||||
"Learning Rate Final": "最终学习率",
|
||||
"Micro Batch Size": "微批次大小",
|
||||
"Accumulate Gradient Batches": "梯度累积批次",
|
||||
"Warmup Steps": "学习率预热步数",
|
||||
"Pre-FFN": "前馈网络预处理",
|
||||
"None": "空",
|
||||
"Merge model successfully": "合并模型成功",
|
||||
"Convert Data successfully": "数据转换成功"
|
||||
}
|
@ -1,13 +1,542 @@
|
||||
import React, { FC } from 'react';
|
||||
import { Text } from '@fluentui/react-components';
|
||||
import React, { FC, ReactElement, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components';
|
||||
import {
|
||||
ConvertData,
|
||||
FileExists,
|
||||
MergeLora,
|
||||
OpenFileFolder,
|
||||
WslCommand,
|
||||
WslEnable,
|
||||
WslInstallUbuntu,
|
||||
WslIsEnabled,
|
||||
WslStart,
|
||||
WslStop
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
import { toast } from 'react-toastify';
|
||||
import commonStore from '../stores/commonStore';
|
||||
import { observer } from 'mobx-react-lite';
|
||||
import { SelectTabEventHandler } from '@fluentui/react-tabs';
|
||||
import { refreshLocalModels, toastWithButton } from '../utils';
|
||||
import { Section } from '../components/Section';
|
||||
import { Labeled } from '../components/Labeled';
|
||||
import { ToolTipButton } from '../components/ToolTipButton';
|
||||
import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { Precision } from './Configs';
|
||||
import {
|
||||
CategoryScale,
|
||||
Chart as ChartJS,
|
||||
Legend,
|
||||
LinearScale,
|
||||
LineElement,
|
||||
PointElement,
|
||||
Title,
|
||||
Tooltip
|
||||
} from 'chart.js';
|
||||
import { Line } from 'react-chartjs-2';
|
||||
import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types';
|
||||
|
||||
ChartJS.register(
|
||||
CategoryScale,
|
||||
LinearScale,
|
||||
PointElement,
|
||||
LineElement,
|
||||
Tooltip,
|
||||
Title,
|
||||
Legend
|
||||
);
|
||||
|
||||
const parseLossData = (data: string) => {
|
||||
const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*\| (\d+)\/(\d+) \[(\d+:\d+)<(\d+:\d+),\s+(\d+.\d+it\/s), loss=(\d+.\d+),[\s\S]*\]/g;
|
||||
const matches = Array.from(data.matchAll(regex));
|
||||
if (matches.length === 0)
|
||||
return;
|
||||
const lastMatch = matches[matches.length - 1];
|
||||
const epoch = parseInt(lastMatch[1]);
|
||||
const loss = parseFloat(lastMatch[8]);
|
||||
commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`);
|
||||
addLossDataToChart(epoch, loss);
|
||||
};
|
||||
|
||||
let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>;
|
||||
|
||||
const addLossDataToChart = (epoch: number, loss: number) => {
|
||||
const epochIndex = commonStore.chartData.labels!.findIndex(l => l.includes(epoch.toString()));
|
||||
if (epochIndex === -1) {
|
||||
if (epoch === 0) {
|
||||
commonStore.chartData.labels!.push('Init');
|
||||
commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss];
|
||||
}
|
||||
commonStore.chartData.labels!.push('Epoch ' + epoch.toString());
|
||||
commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss];
|
||||
} else {
|
||||
if (chartLine) {
|
||||
const newData = [...commonStore.chartData.datasets[0].data];
|
||||
newData[epochIndex] = loss;
|
||||
chartLine.data.datasets[0].data = newData;
|
||||
chartLine.update();
|
||||
}
|
||||
}
|
||||
commonStore.setChartData(commonStore.chartData);
|
||||
};
|
||||
|
||||
export type DataProcessParameters = {
|
||||
dataPath: string;
|
||||
vocabPath: string;
|
||||
}
|
||||
|
||||
export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32';
|
||||
|
||||
export type LoraFinetuneParameters = {
|
||||
baseModel: string;
|
||||
ctxLen: number;
|
||||
epochSteps: number;
|
||||
epochCount: number;
|
||||
epochBegin: number;
|
||||
epochSave: number;
|
||||
microBsz: number;
|
||||
accumGradBatches: number;
|
||||
preFfn: boolean;
|
||||
headQk: boolean;
|
||||
lrInit: string;
|
||||
lrFinal: string;
|
||||
warmupSteps: number;
|
||||
beta1: number;
|
||||
beta2: number;
|
||||
adamEps: string;
|
||||
devices: number;
|
||||
precision: LoraFinetunePrecision;
|
||||
gradCp: boolean;
|
||||
loraR: number;
|
||||
loraAlpha: number;
|
||||
loraDropout: number;
|
||||
loraLoad: string
|
||||
}
|
||||
|
||||
const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [
|
||||
['devices', 'number', 'Devices'],
|
||||
['precision', 'LoraFinetunePrecision', 'Precision'],
|
||||
['gradCp', 'boolean', 'Gradient Checkpoint'],
|
||||
['ctxLen', 'number', 'Context Length'],
|
||||
['epochSteps', 'number', 'Epoch Steps'],
|
||||
['epochCount', 'number', 'Epoch Count'],
|
||||
['epochBegin', 'number', 'Epoch Begin'],
|
||||
['epochSave', 'number', 'Epoch Save'],
|
||||
['lrInit', 'string', 'Learning Rate Init'],
|
||||
['lrFinal', 'string', 'Learning Rate Final'],
|
||||
['microBsz', 'number', 'Micro Batch Size'],
|
||||
['accumGradBatches', 'number', 'Accumulate Gradient Batches'],
|
||||
['warmupSteps', 'number', 'Warmup Steps'],
|
||||
['adamEps', 'string', 'Adam Epsilon'],
|
||||
['beta1', 'number', 'Beta 1'],
|
||||
['beta2', 'number', 'Beta 2'],
|
||||
['loraR', 'number', 'LoRA R'],
|
||||
['loraAlpha', 'number', 'LoRA Alpha'],
|
||||
['loraDropout', 'number', 'LoRA Dropout'],
|
||||
['beta1', 'any', ''],
|
||||
['preFfn', 'boolean', 'Pre-FFN'],
|
||||
['headQk', 'boolean', 'Head QK']
|
||||
];
|
||||
|
||||
export const wslHandler = (data: string) => {
|
||||
if (data) {
|
||||
addWslMessage(data);
|
||||
parseLossData(data);
|
||||
}
|
||||
};
|
||||
|
||||
const addWslMessage = (message: string) => {
|
||||
const newData = commonStore.wslStdout + '\n' + message;
|
||||
let lines = newData.split('\n');
|
||||
const result = lines.slice(-100).join('\n');
|
||||
commonStore.setWslStdout(result);
|
||||
};
|
||||
|
||||
const TerminalDisplay: FC = observer(() => {
|
||||
const bodyRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const scrollToBottom = () => {
|
||||
if (bodyRef.current)
|
||||
bodyRef.current.scrollTop = bodyRef.current.scrollHeight;
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
});
|
||||
|
||||
return (
|
||||
<div ref={bodyRef} className="grow overflow-x-hidden overflow-y-auto border-gray-500 border-2 rounded-md">
|
||||
<div className="whitespace-pre-line">
|
||||
{commonStore.wslStdout}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const Terminal: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const [input, setInput] = useState('');
|
||||
|
||||
const handleKeyDown = (e: any) => {
|
||||
e.stopPropagation();
|
||||
if (e.keyCode === 13) {
|
||||
e.preventDefault();
|
||||
if (!input) return;
|
||||
|
||||
WslStart().then(() => {
|
||||
addWslMessage('WSL> ' + input);
|
||||
setInput('');
|
||||
WslCommand(input).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full gap-4">
|
||||
<TerminalDisplay />
|
||||
<div className="flex gap-2 items-center">
|
||||
WSL:
|
||||
<Input className="grow" value={input} onChange={(e) => {
|
||||
setInput(e.target.value);
|
||||
}} onKeyDown={handleKeyDown}></Input>
|
||||
<Button onClick={() => {
|
||||
WslStop().then(() => {
|
||||
toast(t('Command Stopped'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>
|
||||
{t('Stop')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
const LoraFinetune: FC = observer(() => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const chartRef = useRef<ChartJSOrUndefined<'line', (number | null)[], string>>(null);
|
||||
|
||||
const dataParams = commonStore.dataProcessParams;
|
||||
const loraParams = commonStore.loraFinetuneParams;
|
||||
|
||||
if (chartRef.current)
|
||||
chartLine = chartRef.current;
|
||||
|
||||
const setDataParams = (newParams: Partial<DataProcessParameters>) => {
|
||||
commonStore.setDataProcessParams({
|
||||
...dataParams,
|
||||
...newParams
|
||||
});
|
||||
};
|
||||
|
||||
const setLoraParams = (newParams: Partial<LoraFinetuneParameters>) => {
|
||||
commonStore.setLoraFinetuneParameters({
|
||||
...loraParams,
|
||||
...newParams
|
||||
});
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (loraParams.baseModel === '')
|
||||
setLoraParams({
|
||||
baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || ''
|
||||
});
|
||||
}, []);
|
||||
|
||||
const StartLoraFinetune = () => {
|
||||
WslIsEnabled().then(() => {
|
||||
WslStart().then(async () => {
|
||||
const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`;
|
||||
if (!await FileExists(convertedDataPath + '.idx')) {
|
||||
toast(t('Please convert data first.'), { type: 'error' });
|
||||
return;
|
||||
}
|
||||
|
||||
commonStore.setChartData({
|
||||
labels: [],
|
||||
datasets: [
|
||||
{
|
||||
label: 'Loss',
|
||||
data: [],
|
||||
borderColor: 'rgb(53, 162, 235)',
|
||||
backgroundColor: 'rgba(53, 162, 235, 0.5)'
|
||||
}
|
||||
]
|
||||
});
|
||||
WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'} ` +
|
||||
`&& export loadModel=models/${loraParams.baseModel} ` +
|
||||
`&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` +
|
||||
(loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') +
|
||||
(loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') +
|
||||
`--data_file ${convertedDataPath} ` +
|
||||
`--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` +
|
||||
`--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` +
|
||||
`--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` +
|
||||
`--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` +
|
||||
`--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` +
|
||||
`--warmup_steps ${loraParams.warmupSteps} ` +
|
||||
`--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` +
|
||||
`--devices ${loraParams.devices} --precision ${loraParams.precision} ` +
|
||||
`--grad_cp ${loraParams.gradCp ? '1' : '0'} ` +
|
||||
`--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}).catch(e => {
|
||||
const msg = e.message || e;
|
||||
if (msg === 'ubuntu not found') {
|
||||
toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => {
|
||||
WslInstallUbuntu().then(() => {
|
||||
toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 });
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
}).catch(e => {
|
||||
const msg = e.message || e;
|
||||
|
||||
const enableWsl = (forceMode: boolean) => {
|
||||
toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => {
|
||||
WslEnable(forceMode).then(() => {
|
||||
toast(t('After installation, please restart your computer to enable WSL'), {
|
||||
type: 'info',
|
||||
autoClose: false
|
||||
});
|
||||
}).catch(e => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
if (msg === 'wsl is not enabled') {
|
||||
enableWsl(false);
|
||||
} else if (msg.includes('wsl.state: The system cannot find the file')) {
|
||||
enableWsl(true);
|
||||
} else {
|
||||
toast(msg, { type: 'error' });
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full w-full gap-2">
|
||||
{(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) &&
|
||||
<div className="flex" style={{ height: '35%' }}>
|
||||
{commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && <TerminalDisplay />}
|
||||
{commonStore.chartData.labels!.length !== 0 &&
|
||||
<Line ref={chartRef} data={commonStore.chartData} options={{
|
||||
responsive: true,
|
||||
showLine: true,
|
||||
plugins: {
|
||||
legend: {
|
||||
position: 'right',
|
||||
align: 'start'
|
||||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: commonStore.chartTitle
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
y: {
|
||||
beginAtZero: true
|
||||
}
|
||||
},
|
||||
maintainAspectRatio: false
|
||||
}} style={{ width: '100%' }} />}
|
||||
</div>
|
||||
}
|
||||
<div>
|
||||
<Section
|
||||
title={t('Data Process')}
|
||||
content={
|
||||
<div className="flex flex-col gap-2">
|
||||
<Labeled flex label={t('Data Path')}
|
||||
content={
|
||||
<div className="grow flex gap-2">
|
||||
<Input className="grow ml-2" value={dataParams.dataPath}
|
||||
onChange={(e, data) => {
|
||||
setDataParams({ dataPath: data.value });
|
||||
}} />
|
||||
<ToolTipButton desc={t('Open Folder')} icon={<Folder20Regular />} onClick={() => {
|
||||
OpenFileFolder(dataParams.dataPath, false);
|
||||
}} />
|
||||
</div>
|
||||
} />
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('Vocab Path')}
|
||||
<Input className="grow" style={{ minWidth: 0 }} value={dataParams.vocabPath}
|
||||
onChange={(e, data) => {
|
||||
setDataParams({ vocabPath: data.value });
|
||||
}} />
|
||||
<Button appearance="secondary" size="large" onClick={() => {
|
||||
ConvertData(commonStore.settings.customPythonPath, dataParams.dataPath,
|
||||
'./finetune/json2binidx_tool/data/' + dataParams.dataPath.split('/').pop()!.split('.')[0],
|
||||
dataParams.vocabPath).then(() => {
|
||||
toast(t('Convert Data successfully'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Convert')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<Section
|
||||
title={t('Train Parameters')}
|
||||
content={
|
||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('Base Model')}
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams.baseModel}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
baseModel: data.value
|
||||
});
|
||||
}}>
|
||||
{commonStore.modelSourceList.map((modelItem, index) =>
|
||||
modelItem.isComplete && <option key={index} value={modelItem.name}>{modelItem.name}</option>
|
||||
)}
|
||||
</Select>
|
||||
<ToolTipButton desc={t('Manage Models')} icon={<DataUsageSettings20Regular />} onClick={() => {
|
||||
navigate({ pathname: '/models' });
|
||||
}} />
|
||||
</div>
|
||||
<div className="flex gap-2 items-center">
|
||||
{t('LoRA Model')}
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams.loraLoad}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
loraLoad: data.value
|
||||
});
|
||||
}}>
|
||||
<option value="">{t('None')}</option>
|
||||
{commonStore.loraModels.map((name, index) =>
|
||||
<option key={index} value={name}>{name}</option>
|
||||
)}
|
||||
</Select>
|
||||
<Button onClick={() => {
|
||||
MergeLora(commonStore.settings.customPythonPath, true, loraParams.loraAlpha,
|
||||
'models/' + loraParams.baseModel, 'lora-models/' + loraParams.loraLoad,
|
||||
`models/${loraParams.baseModel}-LoRA-${loraParams.loraLoad}`).then(() => {
|
||||
toast(t('Merge model successfully'), { type: 'success' });
|
||||
refreshLocalModels({ models: commonStore.modelSourceList }, false);
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Merge Model')}</Button>
|
||||
</div>
|
||||
{
|
||||
loraFinetuneParametersOptions.map(([key, type, name], index) => {
|
||||
return (
|
||||
<Labeled key={index} label={t(name)} content={
|
||||
type === 'number' ?
|
||||
<Input type="number" className="grow" value={loraParams[key].toString()}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: Number(data.value)
|
||||
});
|
||||
}} /> :
|
||||
type === 'boolean' ?
|
||||
<Switch className="grow" checked={loraParams[key] as boolean}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: data.checked
|
||||
});
|
||||
}} /> :
|
||||
type === 'string' ?
|
||||
<Input className="grow" value={loraParams[key].toString()}
|
||||
onChange={(e, data) => {
|
||||
setLoraParams({
|
||||
[key]: data.value
|
||||
});
|
||||
}} /> :
|
||||
type === 'LoraFinetunePrecision' ?
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow"
|
||||
value={loraParams[key].toString()}
|
||||
selectedOptions={[loraParams[key].toString()]}
|
||||
onOptionSelect={(_, data) => {
|
||||
if (data.optionText) {
|
||||
setLoraParams({
|
||||
precision: data.optionText as LoraFinetunePrecision
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Option>bf16</Option>
|
||||
<Option>fp16</Option>
|
||||
<Option>fp32</Option>
|
||||
<Option>tf32</Option>
|
||||
</Dropdown>
|
||||
: <div />
|
||||
} />
|
||||
);
|
||||
})
|
||||
}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
<div className="grow" />
|
||||
<div className="flex gap-2">
|
||||
<div className="grow" />
|
||||
<Button appearance="secondary" size="large" onClick={() => {
|
||||
WslStop().then(() => {
|
||||
toast(t('Command Stopped'), { type: 'success' });
|
||||
}).catch((e) => {
|
||||
toast((e.message || e), { type: 'error' });
|
||||
});
|
||||
}}>{t('Stop')}</Button>
|
||||
<Button appearance="primary" size="large" onClick={StartLoraFinetune}>{t('Train')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
type TrainNavigationItem = {
|
||||
element: ReactElement;
|
||||
};
|
||||
|
||||
const pages: { [label: string]: TrainNavigationItem } = {
|
||||
'LoRA Finetune': {
|
||||
element: <LoraFinetune />
|
||||
},
|
||||
WSL: {
|
||||
element: <Terminal />
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
export const Train: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const [tab, setTab] = useState('LoRA Finetune');
|
||||
|
||||
return (
|
||||
<div className="flex flex-col box-border gap-5 p-2">
|
||||
<Text size={600}>{t('In Development')}</Text>
|
||||
const selectTab: SelectTabEventHandler = (e, data) =>
|
||||
typeof data.value === 'string' ? setTab(data.value) : null;
|
||||
|
||||
return <div className="flex flex-col gap-2 w-full h-full">
|
||||
<TabList
|
||||
size="small"
|
||||
appearance="subtle"
|
||||
selectedValue={tab}
|
||||
onTabSelect={selectTab}
|
||||
>
|
||||
{Object.entries(pages).map(([label]) => (
|
||||
<Tab key={label} value={label}>
|
||||
{t(label)}
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
<div className="grow overflow-hidden">
|
||||
{pages[tab].element}
|
||||
</div>
|
||||
);
|
||||
</div>;
|
||||
};
|
||||
|
@ -1,11 +1,12 @@
|
||||
import commonStore, { Platform } from './stores/commonStore';
|
||||
import { GetPlatform, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshModels } from './utils';
|
||||
import { getStatus } from './apis';
|
||||
import { EventsOn } from '../wailsjs/runtime';
|
||||
import manifest from '../../manifest.json';
|
||||
import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultModelConfigs';
|
||||
import { Preset } from './pages/PresetsManager/PresetsButton';
|
||||
import { wslHandler } from './pages/Train';
|
||||
|
||||
export async function startup() {
|
||||
downloadProgramFiles();
|
||||
@ -13,9 +14,14 @@ export async function startup() {
|
||||
if (data)
|
||||
commonStore.setDownloadList(data);
|
||||
});
|
||||
EventsOn('wsl', wslHandler);
|
||||
EventsOn('wslerr', (e) => {
|
||||
console.log(e);
|
||||
});
|
||||
initLoraModels();
|
||||
|
||||
initPresets();
|
||||
|
||||
|
||||
await GetPlatform().then(p => commonStore.setPlatform(p as Platform));
|
||||
await initConfig();
|
||||
|
||||
@ -50,6 +56,9 @@ async function initConfig() {
|
||||
if (configData.settings)
|
||||
commonStore.setSettings(configData.settings, false);
|
||||
|
||||
if (configData.dataProcessParams)
|
||||
commonStore.setDataProcessParams(configData.dataProcessParams, false);
|
||||
|
||||
if (configData.modelConfigs && Array.isArray(configData.modelConfigs))
|
||||
commonStore.setModelConfigs(configData.modelConfigs, false);
|
||||
else throw new Error('Invalid config.json');
|
||||
@ -76,3 +85,24 @@ async function initPresets() {
|
||||
}).catch(() => {
|
||||
});
|
||||
}
|
||||
|
||||
async function initLoraModels() {
|
||||
const refreshLoraModels = () => {
|
||||
ListDirFiles('lora-models').then((data) => {
|
||||
if (!data) return;
|
||||
const loraModels = [];
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth')) {
|
||||
loraModels.push(f.name);
|
||||
}
|
||||
}
|
||||
commonStore.setLoraModels(loraModels);
|
||||
});
|
||||
};
|
||||
|
||||
refreshLoraModels();
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if (data.includes('lora-models'))
|
||||
refreshLoraModels();
|
||||
});
|
||||
}
|
||||
|
@ -14,6 +14,8 @@ import { CompletionPreset } from '../pages/Completion';
|
||||
import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultModelConfigs';
|
||||
import commonStore from './commonStore';
|
||||
import { Preset } from '../pages/PresetsManager/PresetsButton';
|
||||
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
|
||||
import { ChartData } from 'chart.js';
|
||||
|
||||
export enum ModelStatus {
|
||||
Offline,
|
||||
@ -30,6 +32,8 @@ export type Status = {
|
||||
|
||||
export type Platform = 'windows' | 'darwin' | 'linux';
|
||||
|
||||
const labels = ['January', 'February', 'March', 'April', 'May', 'June', 'July'];
|
||||
|
||||
class CommonStore {
|
||||
// global
|
||||
status: Status = {
|
||||
@ -62,6 +66,40 @@ class CommonStore {
|
||||
// downloads
|
||||
downloadList: DownloadStatus[] = [];
|
||||
lastUnfinishedModelDownloads: DownloadStatus[] = [];
|
||||
// train
|
||||
wslStdout: string = '';
|
||||
chartTitle: string = '';
|
||||
chartData: ChartData<'line', (number | null)[], string> = { labels: [], datasets: [] };
|
||||
loraModels: string[] = [];
|
||||
dataProcessParams: DataProcessParameters = {
|
||||
dataPath: 'finetune/data/sample.jsonl',
|
||||
vocabPath: 'backend-python/rwkv_pip/rwkv_vocab_v20230424.txt'
|
||||
};
|
||||
loraFinetuneParams: LoraFinetuneParameters = {
|
||||
baseModel: '',
|
||||
ctxLen: 1024,
|
||||
epochSteps: 1000,
|
||||
epochCount: 20,
|
||||
epochBegin: 0,
|
||||
epochSave: 5,
|
||||
microBsz: 1,
|
||||
accumGradBatches: 8,
|
||||
preFfn: false,
|
||||
headQk: false,
|
||||
lrInit: '5e-5',
|
||||
lrFinal: '5e-5',
|
||||
warmupSteps: 0,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
adamEps: '1e-8',
|
||||
devices: 1,
|
||||
precision: 'bf16',
|
||||
gradCp: false,
|
||||
loraR: 8,
|
||||
loraAlpha: 32,
|
||||
loraDropout: 0.01,
|
||||
loraLoad: ''
|
||||
};
|
||||
// settings
|
||||
advancedCollapsed: boolean = true;
|
||||
settings: SettingsType = {
|
||||
@ -228,6 +266,34 @@ class CommonStore {
|
||||
setCompletionSubmittedPrompt(value: string) {
|
||||
this.completionSubmittedPrompt = value;
|
||||
}
|
||||
|
||||
setWslStdout(value: string) {
|
||||
this.wslStdout = value;
|
||||
}
|
||||
|
||||
setDataProcessParams(value: DataProcessParameters, saveConfig: boolean = true) {
|
||||
this.dataProcessParams = value;
|
||||
if (saveConfig)
|
||||
saveConfigs();
|
||||
}
|
||||
|
||||
setLoraFinetuneParameters(value: LoraFinetuneParameters, saveConfig: boolean = true) {
|
||||
this.loraFinetuneParams = value;
|
||||
if (saveConfig)
|
||||
saveConfigs();
|
||||
}
|
||||
|
||||
setChartTitle(value: string) {
|
||||
this.chartTitle = value;
|
||||
}
|
||||
|
||||
setChartData(value: ChartData<'line', (number | null)[], string>) {
|
||||
this.chartData = value;
|
||||
}
|
||||
|
||||
setLoraModels(value: string[]) {
|
||||
this.loraModels = value;
|
||||
}
|
||||
}
|
||||
|
||||
export default new CommonStore();
|
@ -17,6 +17,7 @@ import { Language, Languages, SettingsType } from '../pages/Settings';
|
||||
import { ModelSourceItem } from '../pages/Models';
|
||||
import { ModelConfig, ModelParameters } from '../pages/Configs';
|
||||
import { DownloadStatus } from '../pages/Downloads';
|
||||
import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train';
|
||||
|
||||
export type Cache = {
|
||||
version: string
|
||||
@ -28,7 +29,9 @@ export type LocalConfig = {
|
||||
modelSourceManifestList: string
|
||||
currentModelConfigIndex: number
|
||||
modelConfigs: ModelConfig[]
|
||||
settings: SettingsType
|
||||
settings: SettingsType,
|
||||
dataProcessParams: DataProcessParameters,
|
||||
loraFinetuneParams: LoraFinetuneParameters
|
||||
}
|
||||
|
||||
export async function refreshBuiltInModels(readCache: boolean = false) {
|
||||
@ -194,7 +197,9 @@ export const saveConfigs = async () => {
|
||||
modelSourceManifestList: commonStore.modelSourceManifestList,
|
||||
currentModelConfigIndex: commonStore.currentModelConfigIndex,
|
||||
modelConfigs: commonStore.modelConfigs,
|
||||
settings: commonStore.settings
|
||||
settings: commonStore.settings,
|
||||
dataProcessParams: commonStore.dataProcessParams,
|
||||
loraFinetuneParams: commonStore.loraFinetuneParams
|
||||
};
|
||||
return SaveJson('config.json', data);
|
||||
};
|
||||
|
16
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
16
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
@ -6,6 +6,8 @@ export function AddToDownloadList(arg1:string,arg2:string):Promise<void>;
|
||||
|
||||
export function ContinueDownload(arg1:string):Promise<void>;
|
||||
|
||||
export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
|
||||
|
||||
export function CopyFile(arg1:string,arg2:string):Promise<void>;
|
||||
@ -24,6 +26,8 @@ export function InstallPyDep(arg1:string,arg2:boolean):Promise<string>;
|
||||
|
||||
export function ListDirFiles(arg1:string):Promise<Array<backend_golang.FileInfo>>;
|
||||
|
||||
export function MergeLora(arg1:string,arg2:boolean,arg3:number,arg4:string,arg5:string,arg6:string):Promise<string>;
|
||||
|
||||
export function OpenFileFolder(arg1:string,arg2:boolean):Promise<void>;
|
||||
|
||||
export function OpenSaveFileDialog(arg1:string,arg2:string,arg3:string):Promise<string>;
|
||||
@ -41,3 +45,15 @@ export function SaveJson(arg1:string,arg2:any):Promise<void>;
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string):Promise<string>;
|
||||
|
||||
export function UpdateApp(arg1:string):Promise<boolean>;
|
||||
|
||||
export function WslCommand(arg1:string):Promise<void>;
|
||||
|
||||
export function WslEnable(arg1:boolean):Promise<void>;
|
||||
|
||||
export function WslInstallUbuntu():Promise<void>;
|
||||
|
||||
export function WslIsEnabled():Promise<void>;
|
||||
|
||||
export function WslStart():Promise<void>;
|
||||
|
||||
export function WslStop():Promise<void>;
|
||||
|
32
frontend/wailsjs/go/backend_golang/App.js
generated
32
frontend/wailsjs/go/backend_golang/App.js
generated
@ -10,6 +10,10 @@ export function ContinueDownload(arg1) {
|
||||
return window['go']['backend_golang']['App']['ContinueDownload'](arg1);
|
||||
}
|
||||
|
||||
export function ConvertData(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
export function ConvertModel(arg1, arg2, arg3, arg4) {
|
||||
return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4);
|
||||
}
|
||||
@ -46,6 +50,10 @@ export function ListDirFiles(arg1) {
|
||||
return window['go']['backend_golang']['App']['ListDirFiles'](arg1);
|
||||
}
|
||||
|
||||
export function MergeLora(arg1, arg2, arg3, arg4, arg5, arg6) {
|
||||
return window['go']['backend_golang']['App']['MergeLora'](arg1, arg2, arg3, arg4, arg5, arg6);
|
||||
}
|
||||
|
||||
export function OpenFileFolder(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['OpenFileFolder'](arg1, arg2);
|
||||
}
|
||||
@ -81,3 +89,27 @@ export function StartServer(arg1, arg2, arg3) {
|
||||
export function UpdateApp(arg1) {
|
||||
return window['go']['backend_golang']['App']['UpdateApp'](arg1);
|
||||
}
|
||||
|
||||
export function WslCommand(arg1) {
|
||||
return window['go']['backend_golang']['App']['WslCommand'](arg1);
|
||||
}
|
||||
|
||||
export function WslEnable(arg1) {
|
||||
return window['go']['backend_golang']['App']['WslEnable'](arg1);
|
||||
}
|
||||
|
||||
export function WslInstallUbuntu() {
|
||||
return window['go']['backend_golang']['App']['WslInstallUbuntu']();
|
||||
}
|
||||
|
||||
export function WslIsEnabled() {
|
||||
return window['go']['backend_golang']['App']['WslIsEnabled']();
|
||||
}
|
||||
|
||||
export function WslStart() {
|
||||
return window['go']['backend_golang']['App']['WslStart']();
|
||||
}
|
||||
|
||||
export function WslStop() {
|
||||
return window['go']['backend_golang']['App']['WslStop']();
|
||||
}
|
||||
|
7
go.mod
7
go.mod
@ -5,12 +5,14 @@ go 1.20
|
||||
require (
|
||||
github.com/cavaliergopher/grab/v3 v3.0.1
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e
|
||||
github.com/wailsapp/wails/v2 v2.5.1
|
||||
)
|
||||
|
||||
require (
|
||||
aead.dev/minisign v0.2.0 // indirect
|
||||
github.com/bep/debounce v1.2.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
|
||||
@ -21,17 +23,20 @@ require (
|
||||
github.com/leaanthony/slicer v1.6.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||
github.com/nyaosorg/go-windows-su v0.2.1
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/samber/lo v1.38.1 // indirect
|
||||
github.com/sirupsen/logrus v1.9.0 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.6 // indirect
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasttemplate v1.2.2 // indirect
|
||||
github.com/wailsapp/mimetype v1.4.1 // indirect
|
||||
golang.org/x/crypto v0.9.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
)
|
||||
|
19
go.sum
19
go.sum
@ -1,5 +1,6 @@
|
||||
aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk=
|
||||
aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ=
|
||||
github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg=
|
||||
github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY=
|
||||
github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0=
|
||||
github.com/cavaliergopher/grab/v3 v3.0.1 h1:4z7TkBfmPjmLAAmkkAZNX/6QJ1nNFdv3SdIHXju0Fr4=
|
||||
@ -7,6 +8,8 @@ github.com/cavaliergopher/grab/v3 v3.0.1/go.mod h1:1U/KNnD+Ft6JJiYoYBAimKH2XrYpt
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
@ -37,6 +40,8 @@ github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
|
||||
github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM=
|
||||
github.com/nyaosorg/go-windows-su v0.2.1 h1:5V0XavLyjOqPUp7psxxCvBISaneU4XmFPSMlejSl5sc=
|
||||
github.com/nyaosorg/go-windows-su v0.2.1/go.mod h1:fWKxSCXwGuDuW6ne0kLp/Cj0joXNDDw01G3LseQJYS0=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@ -48,11 +53,17 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
|
||||
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
|
||||
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tkrajina/go-reflector v0.5.6 h1:hKQ0gyocG7vgMD2M3dRlYN6WBBOmdoOzJ6njQSepKdE=
|
||||
github.com/tkrajina/go-reflector v0.5.6/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM=
|
||||
github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g=
|
||||
github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e h1:5hJ4Z9ISvbDUWL7TDvfoYp0bXsaX42WjAUJzyZ8NMCI=
|
||||
github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e/go.mod h1:tu2rOgQGt6bZce1OE8G75Ca8+NvNmTNOvplLolr326I=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ=
|
||||
@ -86,10 +97,12 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
10
main.go
10
main.go
@ -26,12 +26,22 @@ var cyacInfo embed.FS
|
||||
//go:embed backend-python
|
||||
var py embed.FS
|
||||
|
||||
//go:embed finetune
|
||||
var finetune embed.FS
|
||||
|
||||
func main() {
|
||||
if buildInfo, ok := debug.ReadBuildInfo(); !ok || strings.Contains(buildInfo.String(), "-ldflags") {
|
||||
backend.CopyEmbed(cyac)
|
||||
backend.CopyEmbed(cyacInfo)
|
||||
backend.CopyEmbed(py)
|
||||
backend.CopyEmbed(finetune)
|
||||
os.Mkdir("models", os.ModePerm)
|
||||
os.Mkdir("lora-models", os.ModePerm)
|
||||
}
|
||||
|
||||
f, err := os.Create("lora-models/train_log.txt")
|
||||
if err == nil {
|
||||
f.Close()
|
||||
}
|
||||
|
||||
// Create an instance of the app structure
|
||||
|
Loading…
Reference in New Issue
Block a user