diff --git a/.gitattributes b/.gitattributes index f61aba1..d2728ba 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,4 +3,6 @@ backend-python/wkv_cuda_utils/** linguist-vendored backend-python/get-pip.py linguist-vendored backend-python/convert_model.py linguist-vendored build/** linguist-vendored +finetune/lora/** linguist-vendored +finetune/json2binidx_tool/** linguist-vendored frontend/wailsjs/** linguist-generated \ No newline at end of file diff --git a/.gitignore b/.gitignore index 9f182cc..fa7a04b 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,6 @@ __pycache__ *.old .DS_Store *.log.* -*.log \ No newline at end of file +*.log +train_log.txt +finetune/json2binidx_tool/data diff --git a/backend-golang/app.go b/backend-golang/app.go index 11e88b5..46e0110 100644 --- a/backend-golang/app.go +++ b/backend-golang/app.go @@ -9,6 +9,7 @@ import ( "path/filepath" "runtime" + "github.com/fsnotify/fsnotify" "github.com/minio/selfupdate" wruntime "github.com/wailsapp/wails/v2/pkg/runtime" ) @@ -41,6 +42,27 @@ func (a *App) OnStartup(ctx context.Context) { } a.downloadLoop() + + watcher, err := fsnotify.NewWatcher() + if err == nil { + watcher.Add("./lora-models") + watcher.Add("./models") + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + wruntime.EventsEmit(ctx, "fsnotify", event.Name) + case _, ok := <-watcher.Errors: + if !ok { + return + } + } + } + }() + } } func (a *App) UpdateApp(url string) (broken bool, err error) { diff --git a/backend-golang/rwkv.go b/backend-golang/rwkv.go index fcccffa..c399738 100644 --- a/backend-golang/rwkv.go +++ b/backend-golang/rwkv.go @@ -31,6 +31,38 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, out return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy) } +func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) { + var err error + if python == "" { + python, err = GetPython() + } + if err != nil { + return "", err + } + tokenizerType := "HFTokenizer" + if strings.Contains(vocab, "rwkv_vocab_v20230424") { + tokenizerType = "RWKVTokenizer" + } + return Cmd(python, "./finetune/json2binidx_tool/tools/preprocess_data.py", "--input", input, "--output-prefix", outputPrefix, "--vocab", vocab, + "--tokenizer-type", tokenizerType, "--dataset-impl", "mmap", "--append-eod") +} + +func (a *App) MergeLora(python string, useGpu bool, loraAlpha int, baseModel string, loraPath string, outputPath string) (string, error) { + var err error + if python == "" { + python, err = GetPython() + } + if err != nil { + return "", err + } + args := []string{python, "./finetune/lora/merge_lora.py"} + if useGpu { + args = append(args, "--use-gpu") + } + args = append(args, strconv.Itoa(loraAlpha), baseModel, loraPath, outputPath) + return Cmd(args...) 
+} + func (a *App) DepCheck(python string) error { var err error if python == "" { diff --git a/backend-golang/wsl.go b/backend-golang/wsl.go new file mode 100644 index 0000000..72470b6 --- /dev/null +++ b/backend-golang/wsl.go @@ -0,0 +1,202 @@ +package backend_golang + +import ( + "bufio" + "context" + "errors" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + su "github.com/nyaosorg/go-windows-su" + wsl "github.com/ubuntu/gowsl" + wruntime "github.com/wailsapp/wails/v2/pkg/runtime" +) + +var distro *wsl.Distro +var stdin io.WriteCloser +var cmd *exec.Cmd + +func isWslRunning() (bool, error) { + if distro == nil { + return false, nil + } + state, err := distro.State() + if err != nil { + return false, err + } + if state != wsl.Running { + distro = nil + return false, nil + } + return true, nil +} + +func (a *App) WslStart() error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + running, err := isWslRunning() + if err != nil { + return err + } + if running { + return nil + } + distros, err := wsl.RegisteredDistros(context.Background()) + if err != nil { + return err + } + for _, d := range distros { + if strings.Contains(d.Name(), "Ubuntu") { + distro = &d + break + } + } + if distro == nil { + return errors.New("ubuntu not found") + } + + cmd = exec.Command("wsl", "-d", distro.Name(), "-u", "root") + + stdin, err = cmd.StdinPipe() + if err != nil { + return err + } + + stdout, err := cmd.StdoutPipe() + cmd.Stderr = cmd.Stdout + if err != nil { + // stdin.Close() + stdin = nil + return err + } + + go func() { + reader := bufio.NewReader(stdout) + for { + if stdin == nil { + break + } + line, _, err := reader.ReadLine() + if err != nil { + wruntime.EventsEmit(a.ctx, "wslerr", err.Error()) + break + } + wruntime.EventsEmit(a.ctx, "wsl", string(line)) + } + // stdout.Close() + }() + + if err := cmd.Start(); err != nil { + return err + } + return nil +} + +func (a *App) WslCommand(command string) error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + running, err := isWslRunning() + if err != nil { + return err + } + if !running { + return errors.New("wsl not running") + } + _, err = stdin.Write([]byte(command + "\n")) + if err != nil { + return err + } + return nil +} + +func (a *App) WslStop() error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + running, err := isWslRunning() + if err != nil { + return err + } + if !running { + return errors.New("wsl not running") + } + err = cmd.Process.Kill() + cmd = nil + // stdin.Close() + stdin = nil + distro = nil + if err != nil { + return err + } + return nil +} + +func (a *App) WslIsEnabled() error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + ex, err := os.Executable() + if err != nil { + return err + } + exDir := filepath.Dir(ex) + + data, err := os.ReadFile(exDir + "/wsl.state") + if err == nil { + if strings.Contains(string(data), "Enabled") { + return nil + } + } + + cmd := `-Command (Get-WindowsOptionalFeature -Online -FeatureName Microsoft-Windows-Subsystem-Linux).State | Out-File -Encoding utf8 -FilePath ` + exDir + "/wsl.state" + _, err = su.ShellExecute(su.RUNAS, "powershell", cmd, exDir) + if err != nil { + return err + } + time.Sleep(2 * time.Second) + data, err = os.ReadFile(exDir + "/wsl.state") + if err != nil { + return err + } + if strings.Contains(string(data), "Enabled") { + return nil + } else { + return errors.New("wsl is not enabled") + } +} + +func (a 
*App) WslEnable(forceMode bool) error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + cmd := `/online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux` + _, err := su.ShellExecute(su.RUNAS, "dism", cmd, `C:\`) + if err != nil { + return err + } + if forceMode { + os.WriteFile("./wsl.state", []byte("Enabled"), 0644) + } + return nil +} + +func (a *App) WslInstallUbuntu() error { + if runtime.GOOS != "windows" { + return errors.New("wsl not supported") + } + + exec.Command("start", "ms-windows-store://pdp/?ProductId=9PN20MSR04DW").Start() + return nil +} diff --git a/backend-python/dep_check.py b/backend-python/dep_check.py index 82c10c2..883e498 100644 --- a/backend-python/dep_check.py +++ b/backend-python/dep_check.py @@ -1,8 +1,13 @@ +import lm_dataformat +import ftfy +import tqdm import tiktoken import GPUtil import torch import rwkv +import numpy +import tokenizers import fastapi import uvicorn import sse_starlette diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt index d9030ae..8460ac7 100644 Binary files a/backend-python/requirements.txt and b/backend-python/requirements.txt differ diff --git a/backend-python/requirements_versions.txt b/backend-python/requirements_versions.txt index 3cb41fd..54e8033 100644 Binary files a/backend-python/requirements_versions.txt and b/backend-python/requirements_versions.txt differ diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt index 34e0f0c..1bdbf4e 100644 Binary files a/backend-python/requirements_without_cyac.txt and b/backend-python/requirements_without_cyac.txt differ diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index 135aa94..7d59c4c 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -95,27 +95,64 @@ async def eval_rwkv( return await asyncio.sleep(0.1) else: - completion_lock.acquire() - if await request.is_disconnected(): - completion_lock.release() + with completion_lock: + if await request.is_disconnected(): + requests_num = requests_num - 1 + print(f"{request.client} Stop Waiting (Lock)") + quick_log( + request, + None, + "Stop Waiting (Lock). RequestsNum: " + str(requests_num), + ) + return + set_rwkv_config(model, global_var.get(global_var.Model_Config)) + set_rwkv_config(model, body) + + response, prompt_tokens, completion_tokens = "", 0, 0 + for response, delta, prompt_tokens, completion_tokens in model.generate( + prompt, + stop=stop, + ): + if await request.is_disconnected(): + break + if stream: + yield json.dumps( + { + "object": "chat.completion.chunk" + if chat_mode + else "text_completion", + "response": response, + "model": model.name, + "choices": [ + { + "delta": {"content": delta}, + "index": 0, + "finish_reason": None, + } + if chat_mode + else { + "text": delta, + "index": 0, + "finish_reason": None, + } + ], + } + ) + # torch_gc() requests_num = requests_num - 1 - print(f"{request.client} Stop Waiting (Lock)") + if await request.is_disconnected(): + print(f"{request.client} Stop Waiting") + quick_log( + request, + body, + response + "\nStop Waiting. RequestsNum: " + str(requests_num), + ) + return quick_log( request, - None, - "Stop Waiting (Lock). RequestsNum: " + str(requests_num), + body, + response + "\nFinished. 
RequestsNum: " + str(requests_num), ) - return - set_rwkv_config(model, global_var.get(global_var.Model_Config)) - set_rwkv_config(model, body) - - response, prompt_tokens, completion_tokens = "", 0, 0 - for response, delta, prompt_tokens, completion_tokens in model.generate( - prompt, - stop=stop, - ): - if await request.is_disconnected(): - break if stream: yield json.dumps( { @@ -126,86 +163,47 @@ async def eval_rwkv( "model": model.name, "choices": [ { - "delta": {"content": delta}, + "delta": {}, "index": 0, - "finish_reason": None, + "finish_reason": "stop", } if chat_mode else { - "text": delta, + "text": "", "index": 0, - "finish_reason": None, + "finish_reason": "stop", } ], } ) - # torch_gc() - requests_num = requests_num - 1 - completion_lock.release() - if await request.is_disconnected(): - print(f"{request.client} Stop Waiting") - quick_log( - request, - body, - response + "\nStop Waiting. RequestsNum: " + str(requests_num), - ) - return - quick_log( - request, - body, - response + "\nFinished. RequestsNum: " + str(requests_num), - ) - if stream: - yield json.dumps( - { - "object": "chat.completion.chunk" - if chat_mode - else "text_completion", + yield "[DONE]" + else: + yield { + "object": "chat.completion" if chat_mode else "text_completion", "response": response, "model": model.name, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, "choices": [ { - "delta": {}, + "message": { + "role": "assistant", + "content": response, + }, "index": 0, "finish_reason": "stop", } if chat_mode else { - "text": "", + "text": response, "index": 0, "finish_reason": "stop", } ], } - ) - yield "[DONE]" - else: - yield { - "object": "chat.completion" if chat_mode else "text_completion", - "response": response, - "model": model.name, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": response, - }, - "index": 0, - "finish_reason": "stop", - } - if chat_mode - else { - "text": response, - "index": 0, - "finish_reason": "stop", - } - ], - } @router.post("/v1/chat/completions") @@ -372,81 +370,88 @@ async def embeddings(body: EmbeddingsBody, request: Request): return await asyncio.sleep(0.1) else: - completion_lock.acquire() - if await request.is_disconnected(): - completion_lock.release() - requests_num = requests_num - 1 - print(f"{request.client} Stop Waiting (Lock)") - quick_log( - request, - None, - "Stop Waiting (Lock). RequestsNum: " + str(requests_num), - ) - return + with completion_lock: + if await request.is_disconnected(): + requests_num = requests_num - 1 + print(f"{request.client} Stop Waiting (Lock)") + quick_log( + request, + None, + "Stop Waiting (Lock). 
RequestsNum: " + str(requests_num), + ) + return - base64_format = False - if body.encoding_format == "base64": - base64_format = True + base64_format = False + if body.encoding_format == "base64": + base64_format = True - embeddings = [] - prompt_tokens = 0 - if type(body.input) == list: - if type(body.input[0]) == list: - encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002") - for i in range(len(body.input)): - if await request.is_disconnected(): - break - input = encoding.decode(body.input[i]) - embedding, token_len = model.get_embedding(input, body.fast_mode) - prompt_tokens = prompt_tokens + token_len - if base64_format: - embedding = embedding_base64(embedding) - embeddings.append(embedding) - else: - for i in range(len(body.input)): - if await request.is_disconnected(): - break - embedding, token_len = model.get_embedding( - body.input[i], body.fast_mode + embeddings = [] + prompt_tokens = 0 + if type(body.input) == list: + if type(body.input[0]) == list: + encoding = tiktoken.model.encoding_for_model( + "text-embedding-ada-002" ) - prompt_tokens = prompt_tokens + token_len - if base64_format: - embedding = embedding_base64(embedding) - embeddings.append(embedding) - else: - embedding, prompt_tokens = model.get_embedding(body.input, body.fast_mode) - if base64_format: - embedding = embedding_base64(embedding) - embeddings.append(embedding) + for i in range(len(body.input)): + if await request.is_disconnected(): + break + input = encoding.decode(body.input[i]) + embedding, token_len = model.get_embedding( + input, body.fast_mode + ) + prompt_tokens = prompt_tokens + token_len + if base64_format: + embedding = embedding_base64(embedding) + embeddings.append(embedding) + else: + for i in range(len(body.input)): + if await request.is_disconnected(): + break + embedding, token_len = model.get_embedding( + body.input[i], body.fast_mode + ) + prompt_tokens = prompt_tokens + token_len + if base64_format: + embedding = embedding_base64(embedding) + embeddings.append(embedding) + else: + embedding, prompt_tokens = model.get_embedding( + body.input, body.fast_mode + ) + if base64_format: + embedding = embedding_base64(embedding) + embeddings.append(embedding) - requests_num = requests_num - 1 - completion_lock.release() - if await request.is_disconnected(): - print(f"{request.client} Stop Waiting") + requests_num = requests_num - 1 + if await request.is_disconnected(): + print(f"{request.client} Stop Waiting") + quick_log( + request, + None, + "Stop Waiting. RequestsNum: " + str(requests_num), + ) + return quick_log( request, None, - "Stop Waiting. RequestsNum: " + str(requests_num), + "Finished. RequestsNum: " + str(requests_num), ) - return - quick_log( - request, - None, - "Finished. 
RequestsNum: " + str(requests_num), - ) - ret_data = [ - { - "object": "embedding", - "index": i, - "embedding": embedding, + ret_data = [ + { + "object": "embedding", + "index": i, + "embedding": embedding, + } + for i, embedding in enumerate(embeddings) + ] + + return { + "object": "list", + "data": ret_data, + "model": model.name, + "usage": { + "prompt_tokens": prompt_tokens, + "total_tokens": prompt_tokens, + }, } - for i, embedding in enumerate(embeddings) - ] - - return { - "object": "list", - "data": ret_data, - "model": model.name, - "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}, - } diff --git a/finetune/data/sample.jsonl b/finetune/data/sample.jsonl new file mode 100644 index 0000000..5d0ad99 --- /dev/null +++ b/finetune/data/sample.jsonl @@ -0,0 +1,7 @@ +{"text": "1:This is the first document."} +{"text": "2:Hello\nWorld"} +{"text": "3:1+1=2\n1+2=3\n2+2=4"} +{"text": "4:You will be training the GPT version because it's paralleziable and faster to train."} +{"text": "5:Read the inference code in src/model.py and try using the final hidden state(.xx .aa .bb)"} +{"text": "6:You can fine-tune the model with longer ctxLen and it can quickly adapt to longer ctxLens."} +{"text": "7:Consider RWKV 14B. The state has 200 vectors, that is, 5 vectors for each block: fp16 (xx), fp32 (aa), fp32 (bb), fp32 (pp), fp16 (xx)."} \ No newline at end of file diff --git a/finetune/get_layer_and_embd.py b/finetune/get_layer_and_embd.py new file mode 100644 index 0000000..17e3edd --- /dev/null +++ b/finetune/get_layer_and_embd.py @@ -0,0 +1,41 @@ +import torch +import sys +import time +import os +import threading +import gc + + +def file_cleaner(file): + last_pos = 0 + + def cleaner(): + nonlocal last_pos + while True: + time.sleep(0.1) + pos = file.tell() + if pos > last_pos: + os.posix_fadvise( + file.fileno(), last_pos, pos - last_pos, os.POSIX_FADV_DONTNEED + ) + last_pos = pos + + return cleaner + + +model_file = open(sys.argv[1], "rb") +cleaner = file_cleaner(model_file) +cleaner_thread = threading.Thread(target=cleaner, daemon=True) +cleaner_thread.start() + +w = torch.load(model_file, map_location="cpu") +gc.collect() + +n_embd = w["emb.weight"].shape[1] +n_layer = 0 +keys = list(w.keys()) +for x in keys: + layer_id = int(x.split(".")[1]) if ("blocks." 
in x) else 0 + n_layer = max(n_layer, layer_id + 1) + +print(f"--n_layer {n_layer} --n_embd {n_embd}", end="") diff --git a/finetune/install-wsl-dep-and-train.sh b/finetune/install-wsl-dep-and-train.sh new file mode 100644 index 0000000..6faa48e --- /dev/null +++ b/finetune/install-wsl-dep-and-train.sh @@ -0,0 +1,46 @@ +if [[ ${cnMirror} == 1 ]]; then + export PIP_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple" + if grep -q "mirrors.aliyun.com" /etc/apt/sources.list; then + echo "apt cnMirror already set" + else + sudo sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list + sudo apt update + fi +fi + +if dpkg -s "python3-pip" >/dev/null 2>&1; then + echo "pip installed" +else + sudo apt install python3-pip +fi + +if dpkg -s "ninja-build" >/dev/null 2>&1; then + echo "ninja installed" +else + sudo apt install ninja-build +fi + +if dpkg -s "cuda" >/dev/null 2>&1; then + echo "cuda installed" +else + wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin + sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600 + wget https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb + sudo dpkg -i cuda-repo-wsl-ubuntu-11-7-local_11.7.0-1_amd64.deb + sudo cp /var/cuda-repo-wsl-ubuntu-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/ + sudo apt-get update + sudo apt-get -y install cuda +fi + +if python3 -c "import pkg_resources; pkg_resources.require(open('./finetune/requirements.txt',mode='r'))" &>/dev/null; then + echo "requirements satisfied" +else + python3 -m pip install -r ./finetune/requirements.txt +fi + +echo "loading $loadModel" +modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel) +echo $modelInfo + +python3 ./finetune/lora/train.py $modelInfo $@ --proj_dir lora-models --data_type binidx --lora \ + --lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu diff --git a/finetune/json2binidx_tool/tools/indexed_dataset.py b/finetune/json2binidx_tool/tools/indexed_dataset.py new file mode 100644 index 0000000..a1e0544 --- /dev/null +++ b/finetune/json2binidx_tool/tools/indexed_dataset.py @@ -0,0 +1,597 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. 
+ +import os +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + +import numpy as np +import torch + + + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + else: + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + if impl == "infer": + impl = infer_dataset_impl(path) + if impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float32, + 7: np.float64, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." 
+ ) + version = f.read(8) + assert struct.unpack("<Q", version) == (1,) + code, self.element_size = struct.unpack("<QQ", f.read(16)) + self.dtype = dtypes[code] + self._len, self.s = struct.unpack("<QQ", f.read(16)) + self.doc_count = struct.unpack("<Q", f.read(8)) + self.dim_offsets = read_longs(f, self._len + 1) + self.data_offsets = read_longs(f, self._len + 1) + self.sizes = read_longs(f, self.s) + self.doc_idx = read_longs(f, self.doc_count) + + def read_data(self, path): + self.data_file = open(data_file_path(path), "rb", buffering=0) + + def check_index(self, i): + if i < 0 or i >= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimize later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, np_array): + assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype + bytes = self.out_file.write(np_array) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for 
s in np_array.shape: + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape)) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" 0: + doc_ids.append(text_ids) + if self.args.append_eod: + doc_ids[-1].append(Encoder.tokenizer.eod) + ids[key] = doc_ids + return ids, len(text) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", + type=str, + required=True, + help="Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated " + "list", + ) + group.add_argument( + "--jsonl-keys", + nargs="+", + default=["text"], + help="space separate listed of keys to extract from jsonl. Defa", + ) + group.add_argument( + "--num-docs", + default=None, + help="Optional: Number of documents in the input data (if known) for an accurate progress bar.", + type=int, + ) + group = parser.add_argument_group(title="tokenizer") + group.add_argument( + "--tokenizer-type", + type=str, + required=True, + choices=[ + "HFGPT2Tokenizer", + "HFTokenizer", + "GPT2BPETokenizer", + "CharLevelTokenizer", + "TiktokenTokenizer", + "RWKVTokenizer", + ], + help="What type of tokenizer to use.", + ) + group.add_argument( + "--vocab-file", type=str, default=None, help="Path to the vocab file" + ) + group.add_argument( + "--merge-file", + type=str, + default=None, + help="Path to the BPE merge file (if necessary).", + ) + group.add_argument( + "--append-eod", + action="store_true", + help="Append an token to the end of a document.", + ) + group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text") + group = parser.add_argument_group(title="output data") + group.add_argument( + "--output-prefix", + type=str, + required=True, + help="Path to binary output file without suffix", + ) + group.add_argument( + "--dataset-impl", + type=str, + default="mmap", + choices=["lazy", "cached", "mmap"], + help="Dataset implementation to use. Default: mmap", + ) + + group = parser.add_argument_group(title="runtime") + group.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + group.add_argument( + "--log-interval", + type=int, + default=100, + help="Interval between progress updates", + ) + args = parser.parse_args() + args.keep_empty = False + + # some default/dummy values for the tokenizer + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.model_parallel_size = 1 + + return args + + +def yield_from_files(fnames: list, semaphore): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. 
+ + :param fnames: list of filenames + """ + + def yielder(fname, semaphore): + for f in filter(lambda x: x, lmd.Reader(fname).stream_data()): + semaphore.acquire() + yield f + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def main(): + args = get_args() + encoder = Encoder(args) + tokenizer = build_tokenizer(args) + print(f"Vocab size: {tokenizer.vocab_size}") + print(f"Output prefix: {args.output_prefix}") + + # build a semaphore object to stop `yield_from_files` from getting ahead of encoder.encode and + # hence building up memory + semaphore = Semaphore(10000 + args.workers) + + # use multiprocessing to iterate over input documents + fin = yield_from_files(args.input.split(","), semaphore) + + if args.workers > 1: + pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, chunksize=25) + else: + encoder.initializer() + encoded_docs = (encoder.encode(doc) for doc in fin) + + # make a dataset builder for each key in args.jsonl_keys + # each key will output to a different file beginning with args.output_prefix + output_bin_files = {} + output_idx_files = {} + builders = {} + for key in args.jsonl_keys: + output_bin_files[key] = "{}_{}_{}.bin".format( + args.output_prefix, key, "document" + ) + output_idx_files[key] = "{}_{}_{}.idx".format( + args.output_prefix, key, "document" + ) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + + # actually do tokenization + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence + for key, sentences in doc.items(): + for sentence in sentences: + builders[key].add_item(np.array(sentence, dtype=builders[key].dtype)) + # separate with eos token + builders[key].end_document() + + # log progress + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i}{'' if args.num_docs is None else '/' + str(args.num_docs)} documents ({i / elapsed:0.2f} docs/s, {mbs:0.2f} MB/s)." + ) + if i != 0: + pbar.update(args.log_interval) + + # save output file + for key in args.jsonl_keys: + builders[key].finalize(output_idx_files[key]) + + +if __name__ == "__main__": + main() diff --git a/finetune/json2binidx_tool/tools/rwkv_tokenizer.py b/finetune/json2binidx_tool/tools/rwkv_tokenizer.py new file mode 100644 index 0000000..e3f7126 --- /dev/null +++ b/finetune/json2binidx_tool/tools/rwkv_tokenizer.py @@ -0,0 +1,232 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +# Source: https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_tokenizer.py +######################################################################################################## + +import os, sys, time, random + +print(''' +####################################################################################################################### + +This tokenizer is not used in any RWKV models yet. I plan to use it for the future multilang RWKV models. 
+ +Benefits: + +* Good support of most languages, from European to CJK to Arabic and Hindi and more. + +* Clean vocab. Good for code too. Vocab size = 65525 (use 0 for <|endoftext|>). + +* Good at numbers: the numerical tokens are '0'~'9', '10'~'99', ' 0'~' 9', ' 10'~' 99'. + +* Very easy tokenization: + +** The input text must be in UTF-8. + +** Greedy encoding: always pick the longest (in bytes) token (with the highest id) that matches your UTF-8 bytes. + +* The tokenization result is surprisingly good, because the vocab respects word boundaries and UTF-8 boundaries. + +For 10x faster speed: +mypyc rwkv_tokenizer.py +python3 -c "import rwkv_tokenizer" + +####################################################################################################################### +''') + +######################################################################################################## +# Tokenizer #1 (reference, naive, slow) +######################################################################################################## + +class RWKV_TOKENIZER(): + table = None # : list[list[list[bytes]]] = None + good = None # : list[set[int]] + wlen = None # : list[int] + def __init__(self, file_name): + self.vocab_size = 65525 + self.idx2token = {} + sorted = [] # must be already sorted + lines = open(file_name, "r", encoding="utf-8").readlines() + for l in lines: + idx = int(l[:l.index(' ')]) + x = eval(l[l.index(' '):l.rindex(' ')]) + x = x.encode("utf-8") if isinstance(x, str) else x + assert isinstance(x, bytes) + assert len(x) == int(l[l.rindex(' '):]) + sorted += [x] + self.idx2token[idx] = x + + self.token2idx = {} + for k, v in self.idx2token.items(): + self.token2idx[v] = int(k) + + # precompute some tables for fast matching + self.table = [[[] for j in range(256)] for i in range(256)] + self.good = [set() for i in range(256)] + self.wlen = [0 for i in range(256)] + + for i in reversed(range(len(sorted))): # reverse order - match longer tokens first + s = sorted[i] + if len(s) >= 2: + s0 = int(s[0]) + s1 = int(s[1]) + self.table[s0][s1] += [s] + self.wlen[s0] = max(self.wlen[s0], len(s)) + self.good[s0].add(s1) + + def encodeBytes(self, src: bytes): + src_len: int = len(src) + tokens = [] + i: int = 0 + while i < src_len: + s: bytes = src[i : i + 1] + + if i < src_len - 1: + s1: int = int(src[i + 1]) + s0: int = int(src[i]) + if s1 in self.good[s0]: + sss: bytes = src[i : i + self.wlen[s0]] + try: + s = next(filter(sss.startswith, self.table[s0][s1])) + except: + pass + tokens.append(self.token2idx[s]) + i += len(s) + + return tokens + + def decodeBytes(self, tokens): + return b''.join(map(lambda i: self.idx2token[i], tokens)) + + def encode(self, src: str): + return self.encodeBytes(src.encode("utf-8")) + + def decode(self, tokens): + return self.decodeBytes(tokens).decode('utf-8') + + def token_to_id(self, token): + return self.token2idx[token] + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.idx2token + + def printTokens(self, tokens): + for i in tokens: + s = self.idx2token[i] + try: + s = s.decode('utf-8') + except: + pass + print(f'{repr(s)}{i}', end=' ') + # print(repr(s), i) + print() + +######################################################################################################## +# Tokenizer #2 (trie, faster) https://github.com/TkskKurumi/ChatRWKV-TRIE-Tokenizer +######################################################################################################## + +class TRIE: + __slots__ = 
tuple("ch,to,values,front".split(",")) + to:list + values:set + def __init__(self, front=None, ch=None): + self.ch = ch + self.to = [None for ch in range(256)] + self.values = set() + self.front = front + + def __repr__(self): + fr = self + ret = [] + while(fr!=None): + if(fr.ch!=None): + ret.append(fr.ch) + fr = fr.front + return ""%(ret[::-1], self.values) + + def add(self, key:bytes, idx:int=0, val=None): + if(idx == len(key)): + if(val is None): + val = key + self.values.add(val) + return self + ch = key[idx] + if(self.to[ch] is None): + self.to[ch] = TRIE(front=self, ch=ch) + return self.to[ch].add(key, idx=idx+1, val=val) + + def find_longest(self, key:bytes, idx:int=0): + u:TRIE = self + ch:int = key[idx] + + while(u.to[ch] is not None): + u = u.to[ch] + idx += 1 + if(u.values): + ret = idx, u, u.values + if(idx==len(key)): + break + ch = key[idx] + return ret + +class TRIE_TOKENIZER(): + def __init__(self, file_name): + self.vocab_size = 65525 + self.idx2token = {} + sorted = [] # must be already sorted + with open(file_name, "r", encoding="utf-8") as f: + lines = f.readlines() + for l in lines: + idx = int(l[:l.index(' ')]) + x = eval(l[l.index(' '):l.rindex(' ')]) + x = x.encode("utf-8") if isinstance(x, str) else x + assert isinstance(x, bytes) + assert len(x) == int(l[l.rindex(' '):]) + sorted += [x] + self.idx2token[idx] = x + + self.token2idx = {} + for k,v in self.idx2token.items(): + self.token2idx[v] = int(k) + + self.root = TRIE() + for t, i in self.token2idx.items(): + _ = self.root.add(t, val=(t, i)) + + def encodeBytes(self, src:bytes): + idx:int = 0 + tokens = [] + while (idx < len(src)): + _idx:int = idx + idx, _, values = self.root.find_longest(src, idx) + assert(idx != _idx) + _, token = next(iter(values)) + tokens.append(token) + return tokens + + def decodeBytes(self, tokens): + return b''.join(map(lambda i: self.idx2token[i], tokens)) + + def encode(self, src): + return self.encodeBytes(src.encode("utf-8")) + + def decode(self, tokens): + return self.decodeBytes(tokens).decode('utf-8') + + def get_vocab_size(self): + return self.vocab_size + + def get_vocab(self): + return self.idx2token + + def printTokens(self, tokens): + for i in tokens: + s = self.idx2token[i] + try: + s = s.decode('utf-8') + except: + pass + print(f'{repr(s)}{i}', end=' ') + print() diff --git a/finetune/json2binidx_tool/tools/tokenizer.py b/finetune/json2binidx_tool/tools/tokenizer.py new file mode 100644 index 0000000..92f3b3d --- /dev/null +++ b/finetune/json2binidx_tool/tools/tokenizer.py @@ -0,0 +1,205 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Megatron tokenizers.""" + +from abc import ABC +from abc import abstractmethod + +from tokenizers import Tokenizer +from rwkv_tokenizer import RWKV_TOKENIZER, TRIE_TOKENIZER + +from typing import List, Union + + +def build_tokenizer(args): + """Initialize tokenizer.""" + if args.rank == 0: + print("> building {} tokenizer ...".format(args.tokenizer_type), flush=True) + + # Select and instantiate the tokenizer. + + if args.tokenizer_type.lower() == "HFTokenizer".lower(): + assert args.vocab_file is not None + tokenizer = HFTokenizer(args.vocab_file) + elif args.tokenizer_type.lower() == "RWKVTokenizer".lower(): + assert args.vocab_file is not None + tokenizer = RWKVTokenizer(args.vocab_file) + else: + raise NotImplementedError( + "{} tokenizer is not " "implemented.".format(args.tokenizer_type) + ) + + # Add vocab size. + args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) + + return tokenizer + + +def _vocab_size_with_padding(orig_vocab_size, args): + """Pad vocab size so it is divisible by model parallel size and + still having GPU friendly size.""" + + after = orig_vocab_size + multiple = args.make_vocab_size_divisible_by * args.model_parallel_size + while (after % multiple) != 0: + after += 1 + if args.rank == 0: + print( + " > padded vocab (size: {}) with {} dummy tokens " + "(new size: {})".format(orig_vocab_size, after - orig_vocab_size, after), + flush=True, + ) + return after + + +class AbstractTokenizer(ABC): + """Abstract class for tokenizer.""" + + def __init__(self, name): + self.name = name + super().__init__() + + @property + @abstractmethod + def vocab_size(self): + pass + + @property + @abstractmethod + def vocab(self): + """Dictionary from vocab text token to id token.""" + pass + + @property + @abstractmethod + def inv_vocab(self): + """Dictionary from vocab id token to text token.""" + pass + + @abstractmethod + def tokenize(self, text): + pass + + def detokenize(self, token_ids): + raise NotImplementedError( + "detokenizer is not implemented for {} " "tokenizer".format(self.name) + ) + + @property + def cls(self): + raise NotImplementedError( + "CLS is not provided for {} " "tokenizer".format(self.name) + ) + + @property + def sep(self): + raise NotImplementedError( + "SEP is not provided for {} " "tokenizer".format(self.name) + ) + + @property + def pad(self): + raise NotImplementedError( + "PAD is not provided for {} " "tokenizer".format(self.name) + ) + + @property + def eod(self): + raise NotImplementedError( + "EOD is not provided for {} " "tokenizer".format(self.name) + ) + + @property + def mask(self): + raise NotImplementedError( + "MASK is not provided for {} " "tokenizer".format(self.name) + ) + + +class HFTokenizer(AbstractTokenizer): + """Designed to Integrate HF's Tokenizer library.""" + + def __init__(self, vocab_file): + name = "HFTokenizer" + super().__init__(name) + + self.tokenizer = Tokenizer.from_file(vocab_file) + self.eod_id = self.tokenizer.token_to_id("<|endoftext|>") + self.pad_id = self.tokenizer.token_to_id("<|padding|>") + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size() + + @property + def vocab(self): + return self.tokenizer.get_vocab() + + @property + def inv_vocab(self): + return self.tokenizer.decoder + + def tokenize(self, text: str): + return self.tokenizer.encode(text).ids + + def tokenize_batch(self, text_batch: Union[List[str], str]): + return self.tokenizer.encode_batch(text_batch) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property 
+ def eod(self): + return self.eod_id + + +class RWKVTokenizer(AbstractTokenizer): + """RWKV Worlds Tokenizer.""" + + def __init__(self, vocab_file='rwkv_vocab_v20230424.txt'): + name = "RWKVTokenizer" + super().__init__(name) + + self.tokenizer = TRIE_TOKENIZER(vocab_file) + self.eod_id = 0 # self.tokenizer.token_to_id("<|endoftext|>") + # self.pad_id = self.tokenizer.token_to_id("<|padding|>") + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size() + + @property + def vocab(self): + return self.tokenizer.get_vocab() + + @property + def inv_vocab(self): + return self.tokenizer.decode + + def tokenize(self, text: str): + return self.tokenizer.encode(text) + + def tokenize_batch(self, text_batch: Union[List[str], str]): + return self.tokenizer.encode_batch(text_batch) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def eod(self): + return self.eod_id diff --git a/finetune/lora/cuda/wkv_cuda.cu b/finetune/lora/cuda/wkv_cuda.cu new file mode 100644 index 0000000..3d5dadb --- /dev/null +++ b/finetune/lora/cuda/wkv_cuda.cu @@ -0,0 +1,133 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + F aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + const F *__restrict__ const _y, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const y = _y + _offset; + const F *__restrict__ const gy = _gy + _offset; + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F q[Tmax], r[Tmax]; + + F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + + F ww = u + kk; + F p = max(pp, ww); + F e1 = exp(pp - p); + F e2 = exp(ww - p); + const F qq = gy[ii] / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww 
- p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = gu; + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const F kk = k[ii]; + const F vv = v[ii]; + const F yy = y[ii]; + const F qq = q[i]; + const F rr = r[i]; + + F e1 = qq * exp(rr); + F e2 = exp(kk + pp); + gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); + gv[ii] = e1 + e2 * aa; + + const F ww = w + pp; + const F www = rr - u - kk; + const F p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/finetune/lora/cuda/wkv_cuda_bf16.cu b/finetune/lora/cuda/wkv_cuda_bf16.cu new file mode 100644 index 0000000..5b4e4e8 --- /dev/null +++ b/finetune/lora/cuda/wkv_cuda_bf16.cu @@ -0,0 +1,132 @@ +#include +#include +#include "ATen/ATen.h" +#define MIN_VALUE (-1e38) +typedef at::BFloat16 bf16; + +__global__ void kernel_forward(const int B, const int T, const int C, + const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, + bf16 *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = _w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + bf16 *__restrict__ const y = _y + _offset; + + // aa and bb are running sums divided by exp(pp) (to avoid overflow) + float aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } +} + +__global__ void kernel_backward(const int B, const int T, const int C, + const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, + const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy, + bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + float u = float(_u[_c]); + float w = 
_w[_c]; + const bf16 *__restrict__ const k = _k + _offset; + const bf16 *__restrict__ const v = _v + _offset; + const bf16 *__restrict__ const y = _y + _offset; + const bf16 *__restrict__ const gy = _gy + _offset; + bf16 *__restrict__ const gk = _gk + _offset; + bf16 *__restrict__ const gv = _gv + _offset; + + float q[Tmax], r[Tmax]; + + float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + const float qq = float(gy[ii]) / (e1 * bb + e2); + gw += (ga - gb * yy) * e1 * qq; + gu += (vv - yy) * e2 * qq; + q[i] = qq; + r[i] = ww - p; + + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + ga = e1 * (aa + ga); + gb = e1 * (bb + gb); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() + _gu[_offsetBC] = bf16(gu); + + aa = 0, bb = 0, pp = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + const float yy = float(y[ii]); + const float qq = q[i]; + const float rr = r[i]; + + float e1 = qq * exp(rr); + float e2 = exp(kk + pp); + gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); + gv[ii] = bf16(e1 + e2 * aa); + + const float ww = w + pp; + const float www = rr - u - kk; + const float p = max(ww, www); + e1 = exp(ww - p); + e2 = qq * exp(www - p); + aa = e1 * aa + e2; + bb = e1 * bb - e2 * yy; + pp = p; + } +} + +void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); +} diff --git a/finetune/lora/cuda/wkv_op.cpp b/finetune/lora/cuda/wkv_op.cpp new file mode 100644 index 0000000..802021f --- /dev/null +++ b/finetune/lora/cuda/wkv_op.cpp @@ -0,0 +1,21 @@ +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/cuda/wkv_op_bf16.cpp b/finetune/lora/cuda/wkv_op_bf16.cpp new file mode 100644 index 0000000..5783416 --- /dev/null +++ b/finetune/lora/cuda/wkv_op_bf16.cpp @@ -0,0 +1,25 @@ +#include <torch/extension.h> +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); +void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, + torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), + gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/merge_lora.py b/finetune/lora/merge_lora.py new file mode 100644 index 0000000..e43141b --- /dev/null +++ b/finetune/lora/merge_lora.py @@ -0,0 +1,53 @@ +from collections import OrderedDict +import os +import sys +from typing import Dict +import typing +import torch + +if '-h' in sys.argv or '--help' in sys.argv: + print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>') + +if sys.argv[1] == '--use-gpu': + device = 'cuda' + lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] +else: + device = 'cpu' + lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] + + +with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') + # merge LoRA-only slim checkpoint into the main weights + w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') + for k in w_lora.keys(): + w[k] = w_lora[k] + output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() + # merge LoRA weights + keys = list(w.keys()) + for k in keys: + if k.endswith('.weight'): + prefix = k[:-len('.weight')] + lora_A = prefix + '.lora_A' + lora_B = prefix + '.lora_B' + if lora_A in keys: + assert lora_B in keys + print(f'merging {lora_A} and {lora_B} into {k}') + assert w[lora_B].shape[1] == w[lora_A].shape[0] + lora_r = w[lora_B].shape[1] + w[k] = w[k].to(device=device) + w[lora_A] = w[lora_A].to(device=device) + w[lora_B] = w[lora_B].to(device=device) + w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) + output_w[k] = w[k].to(device='cpu', copy=True) + del w[k] + del w[lora_A] + del w[lora_B] + continue + + if 'lora' not in k: + print(f'retaining {k}') + output_w[k] = w[k].clone() + del w[k] + + torch.save(output_w, output) diff --git a/finetune/lora/src/__init__.py b/finetune/lora/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/finetune/lora/src/binidx.py 
b/finetune/lora/src/binidx.py new file mode 100644 index 0000000..369081a --- /dev/null +++ b/finetune/lora/src/binidx.py @@ -0,0 +1,269 @@ +from lib2to3.pgen2 import token +import os +import torch +import numpy as np +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + +def print_rank_0(*message): + pass + # """If distributed is initialized print only on rank 0.""" + # if torch.distributed.is_initialized(): + # if torch.distributed.get_rank() == 0: + # print(*message, flush=True) + # else: + # print(*message, flush=True) + +def _warmup_mmap_file(path): + pass + # with open(path, "rb") as stream: + # while stream.read(100 * 1024 * 1024): + # pass + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: float, + 7: np.double, + 8: np.uint16, +} + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + +def index_file_path(prefix_path): + return prefix_path + ".idx" + +def data_file_path(prefix_path): + return prefix_path + ".bin" + +class MMapIndexedDataset(torch.utils.data.Dataset): + class Index(object): + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + # Write Magic string so we can check the file format then opening it again. + self._file.write(cls._HDR_MAGIC) + # Write version number + # Little endian unsigned 64 Bit integer + self._file.write(struct.pack(" 0: + self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') + self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size + + if args.my_pile_stage > 0: + # assert self.data_size == 332115325534 and self.vocab_size == 50277 + self.samples_per_epoch = args.epoch_steps * args.real_bsz + assert self.samples_per_epoch == 40320 + rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") + dataset_slot = self.data_size // args.ctx_len + if args.my_pile_stage != 4: + assert MaybeIsPrime(args.magic_prime) + assert args.magic_prime % 3 == 2 + assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 + elif args.data_type == "numpy": + self.data = np.load(args.data_file).astype("int") + self.vocab_size = args.vocab_size + rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + self.data_size = len(self.data) + rank_zero_info(f"Data has {self.data_size} tokens.") + elif args.data_type == "uint16": + self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) + self.vocab_size = args.vocab_size + rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + self.data_size = self.data.shape[0] + rank_zero_info(f"Data has {self.data_size} samples.") + elif args.data_type == "wds_img": + self.vocab_size = -1 + self.data_size = -1 + self.data = None + self.error_count = 0 + else: + if args.data_type == "dummy": + rank_zero_info("Building dummy data...") + self.data = "" + for i in range(100000): + aa = (i) % 10000 + bb = (i * i) % 10000 + cc = aa + bb + self.data += f".{aa}+{bb}={cc}." 
+ else: + self.data = open(args.data_file, "r", encoding=args.data_type).read() + rank_zero_info("Building token list...") + unique = sorted(list(set(self.data))) + self.vocab_size = len(unique) + # rank_zero_info() + # for u in unique: + # print(u, end=' ') + # rank_zero_info('\n\n') + xx = 0 + xxObj = {} + for u in unique: + xxObj[xx] = u + xx += 1 + with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: + vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) + self.data_size = len(self.data) + rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.") + self.stoi = {ch: i for i, ch in enumerate(unique)} + self.itos = {i: ch for i, ch in enumerate(unique)} + + def __len__(self): + return self.args.epoch_steps * self.args.micro_bsz + + def __getitem__(self, idx): + args = self.args + rank = self.global_rank + epoch = self.real_epoch + world_size = self.world_size + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") + + if args.data_type == "wds_img": + def init_wds(self, bias=0): + def identity(x): + return x + import webdataset as wds + import torchvision.transforms as transforms + # img_transform = transforms.Compose( + # [transforms.CenterCrop(256)] + # ) + img_transform = transforms.Compose([ + transforms.CenterCrop(512), + transforms.Resize((args.my_img_size)) + ]) + self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity) + for pp in self.data_raw.pipeline: + if 'Resampled' in str(pp): + pp.deterministic = True + def worker_seed(): + return rank*100000+epoch+bias*1e9 + pp.worker_seed = worker_seed + self.data = iter(self.data_raw) + # print(f"WebDataset loaded for rank {rank} epoch {epoch}") + if self.data == None: + init_wds(self) + trial = 0 + while trial < 10: + try: + dd = next(self.data) # jpg, json, txt + break + except: + print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]') + self.error_count += 1 + init_wds(self, self.error_count) + trial += 1 + pass + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}") + # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp: + # tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n") + return dd[0], dd[2] + else: + if args.data_type == "uint16": + i = np.random.randint(0, self.data_size-1) + dix = self.data[i] + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + else: + ctx_len = args.ctx_len + req_len = ctx_len + 1 + magic_prime = args.magic_prime + data = self.data + + if args.my_pile_stage > 0 and args.my_pile_stage != 4: + ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank + + if args.my_qa_mask > 0: + ii_orig = ii + if ii % 2 == 0: + ii = (ii // 2) * args.magic_prime + if args.ctx_len == 1024: + magic_prime = 324331313 + elif args.ctx_len == 2048: + magic_prime = 162165671 + elif args.ctx_len == 4096: + magic_prime = 81082817 + data = self.data_pile + else: + ii = ii // 2 + + factor = (math.sqrt(5) - 1) / 2 + factor = int(magic_prime * factor) + i = ((factor * ii * ii * ii) % magic_prime) * ctx_len + if (args.my_qa_mask == 0) or (data == self.data_pile): + i = i + args.my_pile_shift + # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") + else: + # cheat: pick a random spot in dataset + i = np.random.randint(0, 
self.data_size - req_len) + + if args.data_type == "binidx": + dix = data.get(idx=0, offset=i, length=req_len).astype(int) + elif args.data_type == "numpy": + dix = data[i : i + req_len] + else: + dix = [self.stoi[s] for s in data[i : i + req_len]] + + if args.my_qa_mask == 1: + if data == self.data_pile: + z = [1] * ctx_len + else: + z = [0] * ctx_len + z_sum = 0 + isGood = False + for i in range(3, ctx_len): + if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187: + isGood = True + if dix[i] == 0: + isGood = False + if isGood: + z[i] = 1 + z_sum += 1 + if z_sum == 0: + z = [1] * ctx_len + i = np.random.randint(0, self.data_pile_size - req_len) + dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int) + z = torch.tensor(z, dtype=torch.bfloat16) + + x = torch.tensor(dix[:-1], dtype=torch.long) + y = torch.tensor(dix[1:], dtype=torch.long) + + # if ii_orig < 50: + # # if rank == 1: + # print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:]) + # else: + # exit(0) + + if args.my_qa_mask == 1: + return x, y, z + + return x, y diff --git a/finetune/lora/src/model.py b/finetune/lora/src/model.py new file mode 100644 index 0000000..15f8d82 --- /dev/null +++ b/finetune/lora/src/model.py @@ -0,0 +1,678 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import functools +import os, math, gc, importlib +import torch +# torch._C._jit_set_profiling_executor(True) +# torch._C._jit_set_profiling_mode(True) +import torch.nn as nn +from torch.utils.checkpoint import checkpoint as torch_checkpoint +from torch.nn import functional as F +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from pytorch_lightning.strategies import DeepSpeedStrategy +if importlib.util.find_spec('deepspeed'): + import deepspeed + from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam + +# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam + +LORA_CONFIG = { + "r": 0, + "alpha": 0, + "dropout": 0, + "parts": {"att", "ln", "time"}, +} + + +try: + print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) +except: + os.environ["RWKV_MY_TESTING"] = '' + +def __nop(ob): + return ob + + +MyModule = nn.Module +MyFunction = __nop +if os.environ["RWKV_JIT_ON"] == "1": + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +T_MAX = int(os.environ["RWKV_T_MAX"]) # TAKES LOTS OF VRAM! 
+# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice + +from torch.utils.cpp_extension import load + +if os.environ["RWKV_FLOAT_MODE"] == "bf16": + wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["finetune/lora/cuda/wkv_op_bf16.cpp", "finetune/lora/cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 32) == 0 + w = -torch.exp(w.float().contiguous()) + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + ctx.save_for_backward(w, u, k, v, y) + return y + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 32) == 0 + w, u, k, v, y = ctx.saved_tensors + gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + return (None, None, None, gw, gu, gk, gv) +else: + wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["finetune/lora/cuda/wkv_op.cpp", "finetune/lora/cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 32) == 0 + if "32" in os.environ["RWKV_FLOAT_MODE"]: + w = -torch.exp(w.contiguous()) + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + else: + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + ctx.save_for_backward(w, u, k, v, y) + if "32" in os.environ["RWKV_FLOAT_MODE"]: + return y + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + return y.half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + return y.bfloat16() + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 32) == 0 + w, u, k, v, y = ctx.saved_tensors + gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) + gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) + gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) + gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) + if "32" in os.environ["RWKV_FLOAT_MODE"]: + wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + else: + wkv_cuda.backward(B, T, C, w, u, k, 
v, y, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + if "32" in os.environ["RWKV_FLOAT_MODE"]: + return (None, None, None, gw, gu, gk, gv) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + + +def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w, u, k, v) + + +######################################################################################################## +# LoRA +######################################################################################################## + + +class LoraLinear(nn.Module): + + def __init__(self, in_features: int, out_features: int, bias: bool): + super().__init__() + + self.weight = nn.Parameter(torch.empty((out_features, in_features))) + assert bias == False, "Biased LoraLinear not supported" + + r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[ + "alpha"], LORA_CONFIG["dropout"] + self.lora_A = nn.Parameter(torch.empty(r, in_features)) + self.lora_B = nn.Parameter(torch.empty(out_features, r)) + self.lora_dropout = nn.Dropout(dropout) + self.scaling = alpha / r + + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def forward(self, x): + return ( + F.linear(x, self.weight) + self.scaling * + F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B)) + + +@functools.wraps(LoraLinear) +def make_linear_att(*args, **kwargs): + if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: + return LoraLinear(*args, **kwargs) + else: + return nn.Linear(*args, **kwargs) + + +@functools.wraps(LoraLinear) +def make_linear_ffn(*args, **kwargs): + if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: + return LoraLinear(*args, **kwargs) + else: + return nn.Linear(*args, **kwargs) + + +######################################################################################################## +# RWKV: RWKV Time-mix + RWKV Channel-mix +######################################################################################################## + + +class RWKV_TimeMix(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.ctx_len = args.ctx_len + self.n_embd = args.n_embd + + with torch.no_grad(): # fancy init + ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + + # fancy time_decay + decay_speed = torch.ones(args.dim_att) + for h in range(args.dim_att): + decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.time_decay = nn.Parameter(decay_speed) + # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + # fancy time_first + zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5 + self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag) + + # fancy time_mix + self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) + + self.time_shift = nn.ZeroPad2d((0, 0, 
1, -1)) + + self.key = make_linear_att(args.n_embd, args.dim_att, bias=False) + self.value = make_linear_att(args.n_embd, args.dim_att, bias=False) + self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False) + + self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) + + if 'a' in os.environ["RWKV_MY_TESTING"]: + self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + d_qkv = args.n_embd // 16 + self.qq = nn.Linear(args.n_embd, d_qkv, bias=False) + self.kk = nn.Linear(args.n_embd, d_qkv, bias=False) + self.vv = nn.Linear(args.n_embd, d_qkv, bias=False) + self.oo = nn.Linear(d_qkv, args.n_embd, bias=False) + with torch.no_grad(): + self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + + if 'a' not in os.environ["RWKV_MY_TESTING"]: + @MyFunction + def jit_func(self, x): + xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + sr = torch.sigmoid(r) + return sr, k, v + + def forward(self, x): + B, T, C = x.size() # x = (Batch,Time,Channel) + sr, k, v = self.jit_func(x) + rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + return self.output(rwkv) + + if 'a' in os.environ["RWKV_MY_TESTING"]: + @MyFunction + def QKV(self, q, k, v): + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.att_mask == 0, float('-inf')) + att = F.softmax(att, dim = -1) + x = att @ v + return x + + @MyFunction + def jit_funcQKV(self, x): + xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq) + xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk) + xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv) + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + sr = torch.sigmoid(r) + qq = self.qq(xqq) + kk = self.kk(xkk) + vv = self.vv(xvv) + return sr, k, v, qq, kk, vv + + def forward(self, x): + B, T, C = x.size() # x = (Batch,Time,Channel) + sr, k, v, qq, kk, vv = self.jit_funcQKV(x) + rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv)) + return rwkv + +######################################################################################################## + +class RWKV_ChannelMix(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) + + self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False) + self.receptance = 
make_linear_ffn(args.n_embd, args.n_embd, bias=False) + self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False) + + @MyFunction + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + k = self.key(xk) + k = torch.square(torch.relu(k)) + kv = self.value(k) + return torch.sigmoid(self.receptance(xr)) * kv + +class MishGLU(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) + + x = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + x[0, 0, i] = i / args.n_embd + + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False) + self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) + + @MyFunction + def forward(self, x): + xx = self.time_shift(x) + xa = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xb = x * self.time_mix_r + xx * (1 - self.time_mix_r) + a = self.aa(xa) + b = self.bb(xb) + return self.value(a * F.mish(b)) + +######################################################################################################## +# The RWKV Model with our blocks +######################################################################################################## + + +class Block(nn.Module): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(args.n_embd) + self.ln2 = nn.LayerNorm(args.n_embd) + + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(args.n_embd) + if args.my_pos_emb > 0: + self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd))) + self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd))) + + if self.layer_id == 0 and self.args.pre_ffn > 0: + self.ffnPre = RWKV_ChannelMix(args, 0) + else: + self.att = RWKV_TimeMix(args, layer_id) + + if 'g' in os.environ["RWKV_MY_TESTING"]: + self.ffn = MishGLU(args, layer_id) + else: + self.ffn = RWKV_ChannelMix(args, layer_id) + + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + self.tiny_ln = nn.LayerNorm(args.n_embd) + self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) + self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) + self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) + self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + + def forward(self, x, x_emb=None): + args = self.args + B, T, C = x.size() + if self.layer_id == 0: + x = self.ln0(x) + if args.my_pos_emb > 0: + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] + x = x + pos_emb + + if self.layer_id == 0 and args.pre_ffn > 0: + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + xx = self.tiny_ln(x) + q = self.tiny_q(xx)[:, :T, :] + k = self.tiny_k(xx)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) + c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) + x = x + c @ self.tiny_v(x_emb) + return x + + +class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + 
ctx.save_for_backward(y) + return loss + + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + + +class RWKV(pl.LightningModule): + def __init__(self, args): + super().__init__() + self.args = args + if not hasattr(args, 'dim_att'): + args.dim_att = args.n_embd + if not hasattr(args, 'dim_ffn'): + args.dim_ffn = args.n_embd * 4 + if not hasattr(args, 'tiny_att_layer'): + args.tiny_att_layer = -1 + if not hasattr(args, 'tiny_att_dim'): + args.tiny_att_dim = -1 + + self.emb = nn.Embedding(args.vocab_size, args.n_embd) + + self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)]) + + self.ln_out = nn.LayerNorm(args.n_embd) + self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False) + + if args.head_qk > 0: + self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) + self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) + self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + + def configure_optimizers(self): + args = self.args + if args.layerwise_lr > 0: + lr_1x = set() + lr_2x = set() + lr_3x = set() + for n, p in self.named_parameters(): + if "time_mix" in n: + if args.my_pile_stage == 2: + lr_2x.add(n) + else: + lr_1x.add(n) + elif "time_decay" in n: + if args.my_pile_stage == 2: + lr_3x.add(n) + else: + lr_2x.add(n) + elif "time_first" in n: + lr_3x.add(n) + else: + lr_1x.add(n) + lr_1x = sorted(list(lr_1x)) + lr_2x = sorted(list(lr_2x)) + lr_3x = sorted(list(lr_3x)) + # print('1x', lr_1x) + # print('2x', lr_2x) + # print('3x', lr_3x) + param_dict = {n: p for n, p in self.named_parameters()} + if args.my_pile_stage == 2: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + ] + else: + optim_groups = [ + {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, + {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, + {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + ] + else: + optim_groups = [ + {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + ] + + for g in optim_groups: + g["params"] = [p for p in g["params"] if p.requires_grad] + optim_groups = [g for g in optim_groups if len(g["params"]) > 0] + + if self.deepspeed_offload: + return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) + return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) + + @property + def deepspeed_offload(self) -> bool: + strategy = self.trainer.strategy + if isinstance(strategy, DeepSpeedStrategy): + cfg = strategy.config["zero_optimization"] + return cfg.get("offload_optimizer") or 
cfg.get("offload_param") + return False + + def forward(self, idx): + args = self.args + B, T = idx.size() + assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." + + x = self.emb(idx) + x_emb = x + + if args.tiny_att_dim > 0: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora: + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x, x_emb) + else: + x = block(x, x_emb) + else: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora: + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x) + else: + x = block(x) + + x = self.ln_out(x) + + if args.head_qk > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if "32" in os.environ["RWKV_FLOAT_MODE"]: + c = c @ F.one_hot(idx, num_classes=args.vocab_size) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + return x + + def training_step(self, batch, batch_idx): + args = self.args + if args.my_qa_mask != 1: + idx, targets = batch + logits = self(idx) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + else: + idx, targets, mask = batch + mask = mask.view(-1) + sum_mask = torch.sum(mask).item() + # if sum_mask == 0: + # return torch.tensor([0.0], requires_grad=True) + + logits = self(idx) + if sum_mask == mask.shape[0]: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + # print('rank', self.global_rank, 'loss', loss.item()) + else: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + # loss_raw = loss + loss = torch.sum(loss * mask) / sum_mask + + # torch.set_printoptions(threshold=10000) + # if True: #self.global_rank == 1: + # tmp = '' + # sss = 0 + # ccc = 0 + # for i in range(mask.shape[0]): + # if mask[i] > 0: + # tmp += str(idx.view(-1)[i].item()) + ',' + # sss += loss_raw.view(-1)[i].float().item() + # ccc += 1 + # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) + + return L2Wrap.apply(loss, logits) + + def training_step_end(self, batch_parts): + all = self.all_gather(batch_parts) + if self.trainer.is_global_zero: + self.trainer.my_loss_all = all + + def generate_init_weight(self): + print( + f""" +############################################################################ +# +# Init model weight (slow for large models)... +# +############################################################################ +""" + ) + m = {} + for n in self.state_dict(): + p = self.state_dict()[n] + shape = p.shape + + gain = 1.0 + scale = 1.0 + if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: + m[n] = p + else: + if n == "emb.weight": + scale = -1 * self.args.lr_init + else: + if shape[0] > shape[1]: + gain = math.sqrt(shape[0] / shape[1]) + for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']: + if kk in n: + scale = 0 + if n == "head.weight": + scale = 0.5 + if "head_k." in n: + scale = 0.1 + if "head_q." 
in n: + scale = 0 + + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") + + if self.args.accelerator.upper() == "GPU": + m[n] = torch.empty((shape[0], shape[1]), device="cuda") + else: + m[n] = torch.empty((shape[0], shape[1])) + + if scale == 0: + nn.init.zeros_(m[n]) + elif scale < 0: + nn.init.uniform_(m[n], a=scale, b=-scale) + else: + nn.init.orthogonal_(m[n], gain=gain * scale) + + m[n] = m[n].cpu() + if os.environ["RWKV_FLOAT_MODE"] == "fp16": + m[n] = m[n].half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + m[n] = m[n].bfloat16() + + # if n == "emb.weight": + # print(m[n]) + + gc.collect() + torch.cuda.empty_cache() + return m diff --git a/finetune/lora/src/trainer.py b/finetune/lora/src/trainer.py new file mode 100644 index 0000000..ab65776 --- /dev/null +++ b/finetune/lora/src/trainer.py @@ -0,0 +1,203 @@ +import os, math, time, datetime, subprocess +import torch +from torch.utils.data import DataLoader +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +from .model import LORA_CONFIG + +def my_save(dd, ff): + if '14b-run1' not in ff: + torch.save(dd, ff) + else: + fn = ff.split('/')[-1] + fff = '/dev/shm/' + fn + torch.save(dd, fff) + subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) + +class train_callback(pl.Callback): + def __init__(self, args): + super().__init__() + self.args = args + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + args = self.args + # if args.cuda_cleanup > 0: + # torch.cuda.empty_cache() + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + + # LR schedule + w_step = args.warmup_steps + if args.lr_final == args.lr_init or args.epoch_count == 0: + lr = args.lr_init + else: + decay_step = real_step - args.my_pile_edecay * args.epoch_steps + decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps + progress = (decay_step - w_step + 1) / (decay_total - w_step) + progress = min(1, max(0, progress)) + + if args.lr_final == 0 or args.lr_init == 0: # linear decay + lr = args.lr_init + (args.lr_final - args.lr_init) * progress + else: # exp decay + lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + + if trainer.global_step < w_step: + lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) + # if trainer.is_global_zero: + # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr) + + for param_group in trainer.optimizers[0].param_groups: + if args.layerwise_lr > 0: + param_group["lr"] = lr * param_group["my_lr_scale"] + # print(param_group["lr"], param_group["my_lr_scale"]) + else: + param_group["lr"] = lr + + trainer.my_lr = lr + # rank_zero_info(f"{real_step} {lr}") + + if trainer.global_step == 0: + if trainer.is_global_zero: # logging + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") + trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + try: + print(f"\n{trainer.strategy.config}\n") + trainer.my_log.write(f"{trainer.strategy.config}\n") + except: + pass + trainer.my_log.flush() + if len(args.wandb) > 0: + print("Login to wandb...") + import wandb + wandb.init( + project=args.wandb, + name=args.run_name + " " + args.my_timestamp, + config=args, + save_code=False, + ) + trainer.my_wandb = wandb + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + args = self.args + if trainer.is_global_zero: # logging + 
t_now = time.time_ns() + token_per_step = args.ctx_len * args.real_bsz + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + kt_s = 0 + try: + t_cost = (t_now - trainer.my_time_ns) / 1e9 + kt_s = token_per_step / t_cost / 1000 + self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) + self.log("Kt/s", kt_s, prog_bar=True, on_step=True) + except: + pass + trainer.my_time_ns = t_now + trainer.my_loss = trainer.my_loss_all.float().mean().item() + trainer.my_loss_sum += trainer.my_loss + trainer.my_loss_count += 1 + trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count + self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) + self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) + # self.log("s", real_step, prog_bar=True, on_step=True) + + if len(args.wandb) > 0: + lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} + if kt_s > 0: + lll["kt/s"] = kt_s + trainer.my_wandb.log(lll, step=int(real_step)) + if args.magic_prime > 0: + expand_factor = 2 if args.my_qa_mask > 0 else 1 + if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: + to_save_dict = pl_module.state_dict() + my_save( + to_save_dict, + f"{args.proj_dir}/rwkv-final.pth", + ) + + + def on_train_epoch_start(self, trainer, pl_module): + args = self.args + dataset = trainer.train_dataloader.dataset.datasets + assert "MyDataset" in str(dataset) + dataset.global_rank = trainer.global_rank + dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) + dataset.world_size = trainer.world_size + # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########') + + def on_train_epoch_end(self, trainer, pl_module): + args = self.args + if trainer.is_global_zero: # logging & save state_dict + if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: + if args.data_type == 'wds_img': + raw_dict = pl_module.state_dict() + to_save_dict = {} + for k in raw_dict: + if k.startswith('encoder.') or k.startswith('decoder.'): + to_save_dict[k] = raw_dict[k] + else: + to_save_dict = pl_module.state_dict() + + if args.lora: + enable_time_finetune = 'time' in LORA_CONFIG["parts"] + enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] + lora_dict = {} + for name, state in to_save_dict.items(): + if ('.lora_' in name + or (enable_time_finetune and '.time_' in name) + or (enable_ln_finetune and '.ln' in name)): + lora_dict[name] = state + to_save_dict = lora_dict + + try: + my_save( + to_save_dict, + f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", + ) + except Exception as e: + print('Error\n\n', e, '\n\n') + trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + trainer.my_log.flush() + + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + + +@rank_zero_only +def generate_init_weight(model, init_weight_name): + mm = model.generate_init_weight() + + if model.args.my_pile_stage == 1: + if len(model.args.load_model) > 0: + print(f"Combine weights from {model.args.load_model}...") + load_dict = torch.load(model.args.load_model, map_location="cpu") + for k in load_dict: + assert k in mm + src = load_dict[k] + try: + mm[k] = src.reshape(mm[k].shape) + except: + tmp = mm[k].squeeze().clone() + print(k, src.shape, '-->', 
mm[k].shape) + ss = src.shape[0] + dd = tmp.shape[0] + for i in range(dd): + pos = i / dd * ss + if pos >= ss - 1: + tmp[i] = src[ss-1] + else: + p0 = int(math.floor(pos)) + ii = pos - p0 + tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii) + mm[k] = tmp.reshape(mm[k].shape) + sss = src.squeeze().float().cpu().numpy() + print(sss[:10], '...', sss[-10:]) + mmm = mm[k].squeeze().float().cpu().numpy() + print(mmm[:10], '...', mmm[-10:]) + + print(f"Save to {init_weight_name}...") + torch.save(mm, init_weight_name) + + if model.args.my_pile_stage == 1: + print("Done. Now go for stage 2.") + exit(0) diff --git a/finetune/lora/src/utils.py b/finetune/lora/src/utils.py new file mode 100644 index 0000000..ea25990 --- /dev/null +++ b/finetune/lora/src/utils.py @@ -0,0 +1,130 @@ +import json, time, random, os +import numpy as np +import torch +from torch.nn import functional as F + +time_slot = {} +time_ref = time.time_ns() + +def record_time(name): + if name not in time_slot: + time_slot[name] = 1e20 + tt = (time.time_ns() - time_ref) / 1e9 + if tt < time_slot[name]: + time_slot[name] = tt + +class TOKENIZER(): + def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): + if 'list' in str(type(WORD_NAME)): + self.charMode = False + if WORD_NAME[0] == WORD_NAME[1]: + from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) + else: + from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) + self.vocab_size = len(self.tokenizer) + else: + self.charMode = True + with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + self.word_table = json.load(result_file) + + self.vocab_size = len(self.word_table) + + self.stoi = {v: int(k) for k, v in self.word_table.items()} + self.itos = {int(k): v for k, v in self.word_table.items()} + + self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] + + def refine_context(self, context): + context = context.strip().split('\n') + for c in range(len(context)): + context[c] = context[c].strip().strip('\u3000').strip('\r') + context = list(filter(lambda c: c != '', context)) + context = '\n' + ('\n'.join(context)).strip() + if context == '': + context = '\n' + return context + + def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + # out[self.UNKNOWN_CHAR] = -float('Inf') + lastChar = int(x[-1]) + + probs = F.softmax(out, dim=-1) + + if self.charMode: + if self.itos[lastChar] == '\n': + top_p = top_p_newline + else: + top_p = top_p_usual + else: + top_p = top_p_usual + + if os.environ["RWKV_RUN_DEVICE"] == "cpu": + probs = probs.numpy() + sorted_probs = np.sort(probs)[::-1] + cumulative_probs = np.cumsum(sorted_probs) + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + probs = probs / np.sum(probs) + out = np.random.choice(a=len(probs), p=probs) + return out + else: + sorted_probs = torch.sort(probs, descending=True)[0] + cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + out = torch.multinomial(probs, num_samples=1)[0] + return out + +def MaybeIsPrime(number): + if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): + return True + else: + return False + + +def FermatPrimalityTest(number): + if number > 1: + for time in 
range(3): + randomNumber = random.randint(2, number) - 1 + if pow(randomNumber, number - 1, number) != 1: + return False + return True + else: + return False + + +def MillerRabinPrimalityTest(number): + if number == 2: + return True + elif number == 1 or number % 2 == 0: + return False + oddPartOfNumber = number - 1 + timesTwoDividNumber = 0 + while oddPartOfNumber % 2 == 0: + oddPartOfNumber = oddPartOfNumber // 2 + timesTwoDividNumber = timesTwoDividNumber + 1 + + for time in range(3): + while True: + randomNumber = random.randint(2, number) - 1 + if randomNumber != 0 and randomNumber != 1: + break + + randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number) + + if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): + iterationNumber = 1 + + while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1): + randomNumberWithPower = pow(randomNumberWithPower, 2, number) + iterationNumber = iterationNumber + 1 + if randomNumberWithPower != (number - 1): + return False + + return True diff --git a/finetune/lora/train.py b/finetune/lora/train.py new file mode 100644 index 0000000..0548dab --- /dev/null +++ b/finetune/lora/train.py @@ -0,0 +1,388 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +if __name__ == "__main__": + from argparse import ArgumentParser + from pytorch_lightning import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + + rank_zero_info("########## work in progress ##########") + + ######################################################################################################## + # + # example: train a simple L12-D768 RWKV on dummy data + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "" --data_type "dummy" --vocab_size 0 \ + # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \ + # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \ + # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: train a simple L6-D512 RWKV from scratch on enwik8 + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \ + # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \ + # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 8 --precision bf16 
--strategy deepspeed_stage_2 --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ + # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1 + + parser = ArgumentParser() + + parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" + + parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument("--n_layer", default=6, type=int) + parser.add_argument("--n_embd", default=512, type=int) + parser.add_argument("--dim_att", default=0, type=int) + parser.add_argument("--dim_ffn", default=0, type=int) + parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick + parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim + parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + + parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument("--lr_final", default=1e-5, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument("--adam_eps", default=1e-8, type=float) + + parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument("--my_pile_edecay", default=0, type=int) + parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 
200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) + + parser.add_argument("--my_img_version", default=0, type=str) + parser.add_argument("--my_img_size", default=0, type=int) + parser.add_argument("--my_img_bit", default=0, type=int) + parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip_scale", default=1, type=float) + parser.add_argument("--my_img_l1_scale", default=0, type=float) + parser.add_argument("--my_img_encoder", default='x', type=str) + # parser.add_argument("--my_img_noise_scale", default=0, type=float) + parser.add_argument("--my_sample_len", default=0, type=int) + parser.add_argument("--my_ffn_shift", default=1, type=int) + parser.add_argument("--my_att_shift", default=1, type=int) + parser.add_argument("--my_pos_emb", default=0, type=int) + parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) + parser.add_argument("--my_qa_mask", default=0, type=int) + parser.add_argument("--my_testing", default='', type=str) + + parser.add_argument("--lora", action="store_true") + parser.add_argument("--lora_load", default="", type=str) + parser.add_argument("--lora_r", default=8, type=int) + parser.add_argument("--lora_alpha", default=32, type=float) + parser.add_argument("--lora_dropout", default=0.01, type=float) + parser.add_argument("--lora_parts", default="att,ln,time", type=str) + + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + ######################################################################################################## + + import os, warnings, math, datetime, sys, time, importlib + import numpy as np + import torch + from torch.utils.data import DataLoader + if "deepspeed" in args.strategy: + import deepspeed + import pytorch_lightning as pl + from pytorch_lightning import seed_everything + + if args.random_seed >= 0: + print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + seed_everything(args.random_seed) + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") + warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + # os.environ["WDS_SHOW_SEED"] = "1" + + args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + args.enable_checkpointing = False + args.replace_sampler_ddp = False + args.logger = False + args.gradient_clip_val = 1.0 + args.num_sanity_val_steps = 0 + args.check_val_every_n_epoch = int(1e20) + args.log_every_n_steps = int(1e20) + args.max_epochs = -1 # continue forever + args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz + os.environ["RWKV_T_MAX"] = str(args.ctx_len) + os.environ["RWKV_MY_TESTING"] = args.my_testing + if args.dim_att <= 0: + args.dim_att = args.n_embd + if args.dim_ffn <= 0: + args.dim_ffn = args.n_embd * 4 + + if args.data_type == "wds_img": + args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" + args.proj_dir = f"{args.proj_dir}-{args.run_name}" + else: + args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) + + if args.my_pile_stage > 0: + magic_prime_bak = 
args.magic_prime + if args.ctx_len == 1024: + args.magic_prime = 324331313 + args.epoch_count = 8043 + elif args.ctx_len == 2048: + args.magic_prime = 162165671 + args.epoch_count = 4021 + elif args.ctx_len == 4096: + args.magic_prime = 81082817 + args.epoch_count = 2010 + if args.my_pile_shift < 0: + if args.ctx_len == 1024: + args.my_pile_shift = 0 + elif args.ctx_len == 2048: + args.my_pile_shift = 512 + elif args.ctx_len == 4096: + args.my_pile_shift = 768 + + if magic_prime_bak > 0: + args.magic_prime = magic_prime_bak + + args.epoch_steps = 40320 // args.real_bsz + assert args.epoch_steps * args.real_bsz == 40320 + if args.my_pile_stage == 2: + assert args.lr_final == args.lr_init + if args.my_pile_stage >= 2: # find latest saved model + list_p = [] + for p in os.listdir(args.proj_dir): + if p.startswith("rwkv") and p.endswith(".pth"): + p = ((p.split("-"))[1].split("."))[0] + if p == "init": + p = -1 + else: + p = int(p) + list_p += [p] + list_p.sort() + max_p = list_p[-1] + if len(list_p) > 1: + args.my_pile_prev_p = list_p[-2] # in case max_p is corrupted + if max_p == -1: + args.load_model = f"{args.proj_dir}/rwkv-init.pth" + else: + args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + if args.my_pile_stage == 2: + args.warmup_steps = 10 + else: + args.warmup_steps = 30 + args.epoch_begin = max_p + 1 + + samples_per_epoch = args.epoch_steps * args.real_bsz + tokens_per_epoch = samples_per_epoch * args.ctx_len + rank_zero_info( + f""" +############################################################################ +# +# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# +# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} +# +# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch +# +# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens +# +# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len +# LoRA = {f'enabled, {args.lora_r} r, {args.lora_alpha} alpha, {args.lora_dropout} dropout, on {args.lora_parts}' if args.lora else 'disabled'} +# +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} +# +# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer +# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions) +# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer +# +############################################################################ +""" + ) + rank_zero_info(str(vars(args)) + "\n") + + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + + if args.lr_final == 0 or args.lr_init == 0: + rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + + assert args.precision in ["fp32", "tf32", "fp16", "bf16"] + os.environ["RWKV_FLOAT_MODE"] = args.precision + if args.precision == "fp32": + for i in range(10): + rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + if args.precision == "fp16": + rank_zero_info("\n\nNote: you are using fp16 (might overflow). 
Try bf16 / tf32 for stable training.\n\n") + + os.environ["RWKV_JIT_ON"] = "1" + if "deepspeed_stage_3" in args.strategy: + os.environ["RWKV_JIT_ON"] = "0" + if args.lora and args.grad_cp == 1: + print('!!!!! LoRA Warning: Gradient Checkpointing requires JIT off, disabling it') + os.environ["RWKV_JIT_ON"] = "0" + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + if args.precision == "fp32": + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if "32" in args.precision: + args.precision = 32 + elif args.precision == "fp16": + args.precision = 16 + else: + args.precision = "bf16" + + ######################################################################################################## + + from src.trainer import train_callback, generate_init_weight + from src.dataset import MyDataset + + train_data = MyDataset(args) + args.vocab_size = train_data.vocab_size + + if args.data_type == 'wds_img': + from src.model_img import RWKV_IMG + assert args.lora, "LoRA not yet supported for RWKV_IMG" + model = RWKV_IMG(args) + else: + from src.model import RWKV, LORA_CONFIG, LoraLinear + if args.lora: + assert args.lora_r > 0, "LoRA should have its `r` > 0" + LORA_CONFIG["r"] = args.lora_r + LORA_CONFIG["alpha"] = args.lora_alpha + LORA_CONFIG["dropout"] = args.lora_dropout + LORA_CONFIG["parts"] = set(str(args.lora_parts).split(',')) + enable_time_finetune = 'time' in LORA_CONFIG["parts"] + enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] + model = RWKV(args) + # only train lora parameters + if args.lora: + model.requires_grad_(False) + for name, module in model.named_modules(): + # have to check param name since it may have been wrapped by torchscript + if any(n.startswith("lora_") for n, _ in module.named_parameters()): + print(f' LoRA training module {name}') + for pname, param in module.named_parameters(): + param.requires_grad = 'lora_' in pname + elif enable_ln_finetune and '.ln' in name: + print(f' LoRA additionally training module {name}') + for param in module.parameters(): + param.requires_grad = True + elif enable_time_finetune and any(n.startswith("time") for n, _ in module.named_parameters()): + for pname, param in module.named_parameters(): + if pname.startswith("time"): + print(f' LoRA additionally training parameter {pname}') + param.requires_grad = True + + if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? + init_weight_name = f"{args.proj_dir}/rwkv-init.pth" + generate_init_weight(model, init_weight_name) # save initial weights + args.load_model = init_weight_name + + rank_zero_info(f"########## Loading {args.load_model}... 
##########") + try: + load_dict = torch.load(args.load_model, map_location="cpu") + except: + rank_zero_info(f"Bad checkpoint {args.load_model}") + if args.my_pile_stage >= 2: # try again using another checkpoint + max_p = args.my_pile_prev_p + if max_p == -1: + args.load_model = f"{args.proj_dir}/rwkv-init.pth" + else: + args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + args.epoch_begin = max_p + 1 + rank_zero_info(f"Trying {args.load_model}") + load_dict = torch.load(args.load_model, map_location="cpu") + + if args.load_partial == 1: + load_keys = load_dict.keys() + for k in model.state_dict(): + if k not in load_keys: + load_dict[k] = model.state_dict()[k] + # If using LoRA, the LoRA keys might be missing in the original model + model.load_state_dict(load_dict, strict=(not args.lora)) + if os.path.isfile(args.lora_load): + model.load_state_dict(torch.load(args.lora_load, map_location="cpu"), + strict=False) + + trainer: Trainer = Trainer.from_argparse_args( + args, + callbacks=[train_callback(args)], + ) + + if (args.lr_init > 1e-4 or trainer.world_size * args.micro_bsz * trainer.accumulate_grad_batches < 8): + if 'I_KNOW_WHAT_IM_DOING' in os.environ: + if trainer.global_rank == 0: + print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + print(f' WARNING: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') + print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + else: + if trainer.global_rank == 0: + print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + print(f' ERROR: you are using too large LR ({args.lr_init} > 1e-4) or too small global batch size ({trainer.world_size} * {args.micro_bsz} * {trainer.accumulate_grad_batches} < 8)') + print(f' Unless you are sure this is what you want, adjust them accordingly') + print(f' (to suppress this, set environment variable "I_KNOW_WHAT_IM_DOING")') + print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + exit(0) + + if trainer.global_rank == 0: + for n in model.state_dict(): + shape = model.state_dict()[n].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + else: + print(f"{str(shape[0]).ljust(5)} {n}") + + if "deepspeed" in args.strategy: + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + + # must set shuffle=False, persistent_workers=False (because worker is in another thread) + data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + + trainer.fit(model, data_loader) diff --git a/finetune/requirements.txt b/finetune/requirements.txt new file mode 100644 index 0000000..6342244 --- /dev/null +++ b/finetune/requirements.txt @@ -0,0 +1,3 @@ +torch==1.13.1 +pytorch_lightning==1.9.5 +deepspeed diff --git a/frontend/package-lock.json b/frontend/package-lock.json index a83d946..c039c98 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -12,6 +12,7 @@ "@fluentui/react-icons": "^2.0.201", "@microsoft/fetch-event-source": "^2.0.1", "@primer/octicons-react": "^19.1.0", + "chart.js": "^4.3.0", "classnames": "^2.3.2", "github-markdown-css": "^5.2.0", "i18next": "^22.4.15", @@ -19,6 +20,7 @@ "mobx-react-lite": "^3.4.3", "react": "^18.2.0", "react-beautiful-dnd": "^13.1.1", + 
"react-chartjs-2": "^5.2.0", "react-dom": "^18.2.0", "react-i18next": "^12.2.2", "react-markdown": "^8.0.7", @@ -1903,6 +1905,11 @@ "integrity": "sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw==", "dev": true }, + "node_modules/@kurkle/color": { + "version": "0.3.2", + "resolved": "https://registry.npmjs.org/@kurkle/color/-/color-0.3.2.tgz", + "integrity": "sha512-fuscdXJ9G1qb7W8VdHi+IwRqij3lBkosAm4ydQtEmbY58OzHXqQhvlxqEkoz0yssNVn38bcpRWgA9PP+OGoisw==" + }, "node_modules/@microsoft/fetch-event-source": { "version": "2.0.1", "resolved": "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz", @@ -2258,6 +2265,17 @@ "resolved": "https://registry.npmmirror.com/character-entities/-/character-entities-2.0.2.tgz", "integrity": "sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==" }, + "node_modules/chart.js": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/chart.js/-/chart.js-4.3.0.tgz", + "integrity": "sha512-ynG0E79xGfMaV2xAHdbhwiPLczxnNNnasrmPEXriXsPJGjmhOBYzFVEsB65w2qMDz+CaBJJuJD0inE/ab/h36g==", + "dependencies": { + "@kurkle/color": "^0.3.0" + }, + "engines": { + "pnpm": ">=7" + } + }, "node_modules/chokidar": { "version": "3.5.3", "resolved": "https://registry.npmmirror.com/chokidar/-/chokidar-3.5.3.tgz", @@ -3884,6 +3902,15 @@ "react-dom": "^16.8.5 || ^17.0.0 || ^18.0.0" } }, + "node_modules/react-chartjs-2": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/react-chartjs-2/-/react-chartjs-2-5.2.0.tgz", + "integrity": "sha512-98iN5aguJyVSxp5U3CblRLH67J8gkfyGNbiK3c+l1QI/G4irHMPQw44aEPmjVag+YKTyQ260NcF82GTQ3bdscA==", + "peerDependencies": { + "chart.js": "^4.1.1", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/react-dom": { "version": "18.2.0", "resolved": "https://registry.npmmirror.com/react-dom/-/react-dom-18.2.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 464ea14..dd13e6a 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,6 +13,7 @@ "@fluentui/react-icons": "^2.0.201", "@microsoft/fetch-event-source": "^2.0.1", "@primer/octicons-react": "^19.1.0", + "chart.js": "^4.3.0", "classnames": "^2.3.2", "github-markdown-css": "^5.2.0", "i18next": "^22.4.15", @@ -20,6 +21,7 @@ "mobx-react-lite": "^3.4.3", "react": "^18.2.0", "react-beautiful-dnd": "^13.1.1", + "react-chartjs-2": "^5.2.0", "react-dom": "^18.2.0", "react-i18next": "^12.2.2", "react-markdown": "^8.0.7", diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index e011b60..42d8148 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -189,5 +189,37 @@ "user": "用户", "assistant": "AI", "system": "系统", - "Regenerate": "重新生成" + "Regenerate": "重新生成", + "LoRA Finetune": "LoRA微调", + "Command Stopped": "命令已终止", + "Please convert data first.": "请先转换数据", + "Ubuntu is not installed, do you want to install it?": "Ubuntu未安装,是否安装?", + "Install Ubuntu": "安装Ubuntu", + "Please install Ubuntu using Microsoft Store": "请用Microsoft Store安装Ubuntu", + "WSL is not enabled, do you want to enable it?": "WSL未启用,是否启用?", + "Enable WSL": "启用WSL", + "After installation, please restart your computer to enable WSL": "安装完成后,请重启电脑以启用WSL", + "Data Process": "数据处理", + "Data Path": "数据路径", + "Vocab Path": "词表路径", + "Train Parameters": "训练参数", + "Base Model": "基底模型", + "LoRA Model": "LoRA模型", + "Merge Model": "合并模型", + "Devices": "显卡数量", + "Gradient Checkpoint": 
"梯度检查点标志", + "Context Length": "上下文长度", + "Epoch Steps": "每轮训练步数", + "Epoch Count": "训练轮次", + "Epoch Begin": "起始轮次", + "Epoch Save": "保存间隔轮次", + "Learning Rate Init": "初始学习率", + "Learning Rate Final": "最终学习率", + "Micro Batch Size": "微批次大小", + "Accumulate Gradient Batches": "梯度累积批次", + "Warmup Steps": "学习率预热步数", + "Pre-FFN": "前馈网络预处理", + "None": "空", + "Merge model successfully": "合并模型成功", + "Convert Data successfully": "数据转换成功" } \ No newline at end of file diff --git a/frontend/src/pages/Train.tsx b/frontend/src/pages/Train.tsx index 332f184..dfb7dde 100644 --- a/frontend/src/pages/Train.tsx +++ b/frontend/src/pages/Train.tsx @@ -1,13 +1,542 @@ -import React, { FC } from 'react'; -import { Text } from '@fluentui/react-components'; +import React, { FC, ReactElement, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { Button, Dropdown, Input, Option, Select, Switch, Tab, TabList } from '@fluentui/react-components'; +import { + ConvertData, + FileExists, + MergeLora, + OpenFileFolder, + WslCommand, + WslEnable, + WslInstallUbuntu, + WslIsEnabled, + WslStart, + WslStop +} from '../../wailsjs/go/backend_golang/App'; +import { toast } from 'react-toastify'; +import commonStore from '../stores/commonStore'; +import { observer } from 'mobx-react-lite'; +import { SelectTabEventHandler } from '@fluentui/react-tabs'; +import { refreshLocalModels, toastWithButton } from '../utils'; +import { Section } from '../components/Section'; +import { Labeled } from '../components/Labeled'; +import { ToolTipButton } from '../components/ToolTipButton'; +import { DataUsageSettings20Regular, Folder20Regular } from '@fluentui/react-icons'; +import { useNavigate } from 'react-router'; +import { Precision } from './Configs'; +import { + CategoryScale, + Chart as ChartJS, + Legend, + LinearScale, + LineElement, + PointElement, + Title, + Tooltip +} from 'chart.js'; +import { Line } from 'react-chartjs-2'; +import { ChartJSOrUndefined } from 'react-chartjs-2/dist/types'; + +ChartJS.register( + CategoryScale, + LinearScale, + PointElement, + LineElement, + Tooltip, + Title, + Legend +); + +const parseLossData = (data: string) => { + const regex = /Epoch (\d+):\s+(\d+%)\|[\s\S]*\| (\d+)\/(\d+) \[(\d+:\d+)<(\d+:\d+),\s+(\d+.\d+it\/s), loss=(\d+.\d+),[\s\S]*\]/g; + const matches = Array.from(data.matchAll(regex)); + if (matches.length === 0) + return; + const lastMatch = matches[matches.length - 1]; + const epoch = parseInt(lastMatch[1]); + const loss = parseFloat(lastMatch[8]); + commonStore.setChartTitle(`Epoch ${epoch}: ${lastMatch[2]} - ${lastMatch[3]}/${lastMatch[4]} - ${lastMatch[5]}/${lastMatch[6]} - ${lastMatch[7]} Loss=${loss}`); + addLossDataToChart(epoch, loss); +}; + +let chartLine: ChartJSOrUndefined<'line', (number | null)[], string>; + +const addLossDataToChart = (epoch: number, loss: number) => { + const epochIndex = commonStore.chartData.labels!.findIndex(l => l.includes(epoch.toString())); + if (epochIndex === -1) { + if (epoch === 0) { + commonStore.chartData.labels!.push('Init'); + commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss]; + } + commonStore.chartData.labels!.push('Epoch ' + epoch.toString()); + commonStore.chartData.datasets[0].data = [...commonStore.chartData.datasets[0].data, loss]; + } else { + if (chartLine) { + const newData = [...commonStore.chartData.datasets[0].data]; + newData[epochIndex] = loss; + chartLine.data.datasets[0].data = newData; + chartLine.update(); + } + } + 
commonStore.setChartData(commonStore.chartData); +}; + +export type DataProcessParameters = { + dataPath: string; + vocabPath: string; +} + +export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'fp32' | 'tf32'; + +export type LoraFinetuneParameters = { + baseModel: string; + ctxLen: number; + epochSteps: number; + epochCount: number; + epochBegin: number; + epochSave: number; + microBsz: number; + accumGradBatches: number; + preFfn: boolean; + headQk: boolean; + lrInit: string; + lrFinal: string; + warmupSteps: number; + beta1: number; + beta2: number; + adamEps: string; + devices: number; + precision: LoraFinetunePrecision; + gradCp: boolean; + loraR: number; + loraAlpha: number; + loraDropout: number; + loraLoad: string +} + +const loraFinetuneParametersOptions: Array<[key: keyof LoraFinetuneParameters, type: string, name: string]> = [ + ['devices', 'number', 'Devices'], + ['precision', 'LoraFinetunePrecision', 'Precision'], + ['gradCp', 'boolean', 'Gradient Checkpoint'], + ['ctxLen', 'number', 'Context Length'], + ['epochSteps', 'number', 'Epoch Steps'], + ['epochCount', 'number', 'Epoch Count'], + ['epochBegin', 'number', 'Epoch Begin'], + ['epochSave', 'number', 'Epoch Save'], + ['lrInit', 'string', 'Learning Rate Init'], + ['lrFinal', 'string', 'Learning Rate Final'], + ['microBsz', 'number', 'Micro Batch Size'], + ['accumGradBatches', 'number', 'Accumulate Gradient Batches'], + ['warmupSteps', 'number', 'Warmup Steps'], + ['adamEps', 'string', 'Adam Epsilon'], + ['beta1', 'number', 'Beta 1'], + ['beta2', 'number', 'Beta 2'], + ['loraR', 'number', 'LoRA R'], + ['loraAlpha', 'number', 'LoRA Alpha'], + ['loraDropout', 'number', 'LoRA Dropout'], + ['beta1', 'any', ''], + ['preFfn', 'boolean', 'Pre-FFN'], + ['headQk', 'boolean', 'Head QK'] +]; + +export const wslHandler = (data: string) => { + if (data) { + addWslMessage(data); + parseLossData(data); + } +}; + +const addWslMessage = (message: string) => { + const newData = commonStore.wslStdout + '\n' + message; + let lines = newData.split('\n'); + const result = lines.slice(-100).join('\n'); + commonStore.setWslStdout(result); +}; + +const TerminalDisplay: FC = observer(() => { + const bodyRef = useRef(null); + + const scrollToBottom = () => { + if (bodyRef.current) + bodyRef.current.scrollTop = bodyRef.current.scrollHeight; + }; + + useEffect(() => { + scrollToBottom(); + }); + + return ( +
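+      // Output pane: renders commonStore.wslStdout (trimmed to the last 100 lines by addWslMessage)
+      // and is auto-scrolled to the bottom by the effect above on every update.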
+
+ {commonStore.wslStdout} +
+
+ ); +}); + +const Terminal: FC = observer(() => { + const { t } = useTranslation(); + const [input, setInput] = useState(''); + + const handleKeyDown = (e: any) => { + e.stopPropagation(); + if (e.keyCode === 13) { + e.preventDefault(); + if (!input) return; + + WslStart().then(() => { + addWslMessage('WSL> ' + input); + setInput(''); + WslCommand(input).catch((e) => { + toast((e.message || e), { type: 'error' }); + }); + }).catch((e) => { + toast((e.message || e), { type: 'error' }); + }); + } + }; + + return ( +
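+        // Command row: on Enter, handleKeyDown above starts the WSL session if needed, echoes
+        // "WSL> <command>" into the log, and forwards the input through the WslCommand binding.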
+ +
+ WSL: + { + setInput(e.target.value); + }} onKeyDown={handleKeyDown}> + +
+
+ ); +}); + +const LoraFinetune: FC = observer(() => { + const { t } = useTranslation(); + const navigate = useNavigate(); + const chartRef = useRef>(null); + + const dataParams = commonStore.dataProcessParams; + const loraParams = commonStore.loraFinetuneParams; + + if (chartRef.current) + chartLine = chartRef.current; + + const setDataParams = (newParams: Partial) => { + commonStore.setDataProcessParams({ + ...dataParams, + ...newParams + }); + }; + + const setLoraParams = (newParams: Partial) => { + commonStore.setLoraFinetuneParameters({ + ...loraParams, + ...newParams + }); + }; + + useEffect(() => { + if (loraParams.baseModel === '') + setLoraParams({ + baseModel: commonStore.modelSourceList.find(m => m.isComplete)?.name || '' + }); + }, []); + + const StartLoraFinetune = () => { + WslIsEnabled().then(() => { + WslStart().then(async () => { + const convertedDataPath = `./finetune/json2binidx_tool/data/${dataParams.dataPath.split('/').pop()!.split('.')[0]}_text_document`; + if (!await FileExists(convertedDataPath + '.idx')) { + toast(t('Please convert data first.'), { type: 'error' }); + return; + } + + commonStore.setChartData({ + labels: [], + datasets: [ + { + label: 'Loss', + data: [], + borderColor: 'rgb(53, 162, 235)', + backgroundColor: 'rgba(53, 162, 235, 0.5)' + } + ] + }); + WslCommand(`export cnMirror=${commonStore.settings.cnMirror ? '1' : '0'} ` + + `&& export loadModel=models/${loraParams.baseModel} ` + + `&& chmod +x finetune/install-wsl-dep-and-train.sh && ./finetune/install-wsl-dep-and-train.sh ` + + (loraParams.baseModel ? `--load_model models/${loraParams.baseModel} ` : '') + + (loraParams.loraLoad ? `--lora_load lora-models/${loraParams.loraLoad} ` : '') + + `--data_file ${convertedDataPath} ` + + `--vocab_size ${loraParams.baseModel.toLowerCase().includes('world') ? '65536' : '50277'} ` + + `--ctx_len ${loraParams.ctxLen} --epoch_steps ${loraParams.epochSteps} --epoch_count ${loraParams.epochCount} ` + + `--epoch_begin ${loraParams.epochBegin} --epoch_save ${loraParams.epochSave} ` + + `--micro_bsz ${loraParams.microBsz} --accumulate_grad_batches ${loraParams.accumGradBatches} ` + + `--pre_ffn ${loraParams.preFfn ? '1' : '0'} --head_qk ${loraParams.headQk ? '1' : '0'} --lr_init ${loraParams.lrInit} --lr_final ${loraParams.lrFinal} ` + + `--warmup_steps ${loraParams.warmupSteps} ` + + `--beta1 ${loraParams.beta1} --beta2 ${loraParams.beta2} --adam_eps ${loraParams.adamEps} ` + + `--devices ${loraParams.devices} --precision ${loraParams.precision} ` + + `--grad_cp ${loraParams.gradCp ? 
'1' : '0'} ` + + `--lora_r ${loraParams.loraR} --lora_alpha ${loraParams.loraAlpha} --lora_dropout ${loraParams.loraDropout}`).catch((e) => { + toast((e.message || e), { type: 'error' }); + }); + }).catch(e => { + const msg = e.message || e; + if (msg === 'ubuntu not found') { + toastWithButton(t('Ubuntu is not installed, do you want to install it?'), t('Install Ubuntu'), () => { + WslInstallUbuntu().then(() => { + toast(t('Please install Ubuntu using Microsoft Store'), { type: 'info', autoClose: 6000 }); + }); + }); + } + }); + }).catch(e => { + const msg = e.message || e; + + const enableWsl = (forceMode: boolean) => { + toastWithButton(t('WSL is not enabled, do you want to enable it?'), t('Enable WSL'), () => { + WslEnable(forceMode).then(() => { + toast(t('After installation, please restart your computer to enable WSL'), { + type: 'info', + autoClose: false + }); + }).catch(e => { + toast((e.message || e), { type: 'error' }); + }); + }); + }; + + if (msg === 'wsl is not enabled') { + enableWsl(false); + } else if (msg.includes('wsl.state: The system cannot find the file')) { + enableWsl(true); + } else { + toast(msg, { type: 'error' }); + } + }); + }; + + return ( +
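+      // Rough layout of this form: a monitoring pane at the top (raw WSL output until loss data arrives,
+      // then the react-chartjs-2 Line chart), a "Data Process" section (data path / vocab path and the convert step),
+      // a "Train Parameters" section (base model, LoRA model and the parameter grid), and the
+      // Merge Model / Train action buttons at the bottom.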
+ {(commonStore.wslStdout.length > 0 || commonStore.chartData.labels!.length !== 0) && +
+ {commonStore.wslStdout.length > 0 && commonStore.chartData.labels!.length === 0 && } + {commonStore.chartData.labels!.length !== 0 && + } +
+ } +
+
+ + { + setDataParams({ dataPath: data.value }); + }} /> + } onClick={() => { + OpenFileFolder(dataParams.dataPath, false); + }} /> +
+ } /> +
+ {t('Vocab Path')} + { + setDataParams({ vocabPath: data.value }); + }} /> + +
+
+ } + /> + +
+
+ {t('Base Model')} + + } onClick={() => { + navigate({ pathname: '/models' }); + }} /> +
+
+ {t('LoRA Model')} + + +
+ { + loraFinetuneParametersOptions.map(([key, type, name], index) => { + return ( + { + setLoraParams({ + [key]: Number(data.value) + }); + }} /> : + type === 'boolean' ? + { + setLoraParams({ + [key]: data.checked + }); + }} /> : + type === 'string' ? + { + setLoraParams({ + [key]: data.value + }); + }} /> : + type === 'LoraFinetunePrecision' ? + { + if (data.optionText) { + setLoraParams({ + precision: data.optionText as LoraFinetunePrecision + }); + } + }} + > + + + + + + :
+ } /> + ); + }) + } +
+ } + /> +
+
+
+ + +
+
+ ); +}); + +type TrainNavigationItem = { + element: ReactElement; +}; + +const pages: { [label: string]: TrainNavigationItem } = { + 'LoRA Finetune': { + element: + }, + WSL: { + element: + } +}; + export const Train: FC = () => { const { t } = useTranslation(); + const [tab, setTab] = useState('LoRA Finetune'); - return ( -
- {t('In Development')} + const selectTab: SelectTabEventHandler = (e, data) => + typeof data.value === 'string' ? setTab(data.value) : null; + + return
+ + {Object.entries(pages).map(([label]) => ( + + {t(label)} + + ))} + +
+ {pages[tab].element}
- ); +
; }; diff --git a/frontend/src/startup.ts b/frontend/src/startup.ts index 9fe7ead..b6f69c0 100644 --- a/frontend/src/startup.ts +++ b/frontend/src/startup.ts @@ -1,11 +1,12 @@ import commonStore, { Platform } from './stores/commonStore'; -import { GetPlatform, ReadJson } from '../wailsjs/go/backend_golang/App'; +import { GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App'; import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshModels } from './utils'; import { getStatus } from './apis'; import { EventsOn } from '../wailsjs/runtime'; import manifest from '../../manifest.json'; import { defaultModelConfigs, defaultModelConfigsMac } from './pages/defaultModelConfigs'; import { Preset } from './pages/PresetsManager/PresetsButton'; +import { wslHandler } from './pages/Train'; export async function startup() { downloadProgramFiles(); @@ -13,9 +14,14 @@ export async function startup() { if (data) commonStore.setDownloadList(data); }); + EventsOn('wsl', wslHandler); + EventsOn('wslerr', (e) => { + console.log(e); + }); + initLoraModels(); initPresets(); - + await GetPlatform().then(p => commonStore.setPlatform(p as Platform)); await initConfig(); @@ -50,6 +56,9 @@ async function initConfig() { if (configData.settings) commonStore.setSettings(configData.settings, false); + if (configData.dataProcessParams) + commonStore.setDataProcessParams(configData.dataProcessParams, false); + if (configData.modelConfigs && Array.isArray(configData.modelConfigs)) commonStore.setModelConfigs(configData.modelConfigs, false); else throw new Error('Invalid config.json'); @@ -76,3 +85,24 @@ async function initPresets() { }).catch(() => { }); } + +async function initLoraModels() { + const refreshLoraModels = () => { + ListDirFiles('lora-models').then((data) => { + if (!data) return; + const loraModels = []; + for (const f of data) { + if (!f.isDir && f.name.endsWith('.pth')) { + loraModels.push(f.name); + } + } + commonStore.setLoraModels(loraModels); + }); + }; + + refreshLoraModels(); + EventsOn('fsnotify', (data: string) => { + if (data.includes('lora-models')) + refreshLoraModels(); + }); +} diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts index 30b17fc..bcbcaa2 100644 --- a/frontend/src/stores/commonStore.ts +++ b/frontend/src/stores/commonStore.ts @@ -14,6 +14,8 @@ import { CompletionPreset } from '../pages/Completion'; import { defaultModelConfigs, defaultModelConfigsMac } from '../pages/defaultModelConfigs'; import commonStore from './commonStore'; import { Preset } from '../pages/PresetsManager/PresetsButton'; +import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train'; +import { ChartData } from 'chart.js'; export enum ModelStatus { Offline, @@ -30,6 +32,8 @@ export type Status = { export type Platform = 'windows' | 'darwin' | 'linux'; +const labels = ['January', 'February', 'March', 'April', 'May', 'June', 'July']; + class CommonStore { // global status: Status = { @@ -62,6 +66,40 @@ class CommonStore { // downloads downloadList: DownloadStatus[] = []; lastUnfinishedModelDownloads: DownloadStatus[] = []; + // train + wslStdout: string = ''; + chartTitle: string = ''; + chartData: ChartData<'line', (number | null)[], string> = { labels: [], datasets: [] }; + loraModels: string[] = []; + dataProcessParams: DataProcessParameters = { + dataPath: 'finetune/data/sample.jsonl', + vocabPath: 'backend-python/rwkv_pip/rwkv_vocab_v20230424.txt' + }; + loraFinetuneParams: LoraFinetuneParameters = { + baseModel: '', + 
ctxLen: 1024, + epochSteps: 1000, + epochCount: 20, + epochBegin: 0, + epochSave: 5, + microBsz: 1, + accumGradBatches: 8, + preFfn: false, + headQk: false, + lrInit: '5e-5', + lrFinal: '5e-5', + warmupSteps: 0, + beta1: 0.9, + beta2: 0.999, + adamEps: '1e-8', + devices: 1, + precision: 'bf16', + gradCp: false, + loraR: 8, + loraAlpha: 32, + loraDropout: 0.01, + loraLoad: '' + }; // settings advancedCollapsed: boolean = true; settings: SettingsType = { @@ -228,6 +266,34 @@ class CommonStore { setCompletionSubmittedPrompt(value: string) { this.completionSubmittedPrompt = value; } + + setWslStdout(value: string) { + this.wslStdout = value; + } + + setDataProcessParams(value: DataProcessParameters, saveConfig: boolean = true) { + this.dataProcessParams = value; + if (saveConfig) + saveConfigs(); + } + + setLoraFinetuneParameters(value: LoraFinetuneParameters, saveConfig: boolean = true) { + this.loraFinetuneParams = value; + if (saveConfig) + saveConfigs(); + } + + setChartTitle(value: string) { + this.chartTitle = value; + } + + setChartData(value: ChartData<'line', (number | null)[], string>) { + this.chartData = value; + } + + setLoraModels(value: string[]) { + this.loraModels = value; + } } export default new CommonStore(); \ No newline at end of file diff --git a/frontend/src/utils/index.tsx b/frontend/src/utils/index.tsx index 700d4fc..5a7adc9 100644 --- a/frontend/src/utils/index.tsx +++ b/frontend/src/utils/index.tsx @@ -17,6 +17,7 @@ import { Language, Languages, SettingsType } from '../pages/Settings'; import { ModelSourceItem } from '../pages/Models'; import { ModelConfig, ModelParameters } from '../pages/Configs'; import { DownloadStatus } from '../pages/Downloads'; +import { DataProcessParameters, LoraFinetuneParameters } from '../pages/Train'; export type Cache = { version: string @@ -28,7 +29,9 @@ export type LocalConfig = { modelSourceManifestList: string currentModelConfigIndex: number modelConfigs: ModelConfig[] - settings: SettingsType + settings: SettingsType, + dataProcessParams: DataProcessParameters, + loraFinetuneParams: LoraFinetuneParameters } export async function refreshBuiltInModels(readCache: boolean = false) { @@ -194,7 +197,9 @@ export const saveConfigs = async () => { modelSourceManifestList: commonStore.modelSourceManifestList, currentModelConfigIndex: commonStore.currentModelConfigIndex, modelConfigs: commonStore.modelConfigs, - settings: commonStore.settings + settings: commonStore.settings, + dataProcessParams: commonStore.dataProcessParams, + loraFinetuneParams: commonStore.loraFinetuneParams }; return SaveJson('config.json', data); }; diff --git a/frontend/wailsjs/go/backend_golang/App.d.ts b/frontend/wailsjs/go/backend_golang/App.d.ts index b7b62be..172529a 100755 --- a/frontend/wailsjs/go/backend_golang/App.d.ts +++ b/frontend/wailsjs/go/backend_golang/App.d.ts @@ -6,6 +6,8 @@ export function AddToDownloadList(arg1:string,arg2:string):Promise; export function ContinueDownload(arg1:string):Promise; +export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise; + export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise; export function CopyFile(arg1:string,arg2:string):Promise; @@ -24,6 +26,8 @@ export function InstallPyDep(arg1:string,arg2:boolean):Promise; export function ListDirFiles(arg1:string):Promise>; +export function MergeLora(arg1:string,arg2:boolean,arg3:number,arg4:string,arg5:string,arg6:string):Promise; + export function OpenFileFolder(arg1:string,arg2:boolean):Promise; export function 
OpenSaveFileDialog(arg1:string,arg2:string,arg3:string):Promise; @@ -41,3 +45,15 @@ export function SaveJson(arg1:string,arg2:any):Promise; export function StartServer(arg1:string,arg2:number,arg3:string):Promise; export function UpdateApp(arg1:string):Promise; + +export function WslCommand(arg1:string):Promise; + +export function WslEnable(arg1:boolean):Promise; + +export function WslInstallUbuntu():Promise; + +export function WslIsEnabled():Promise; + +export function WslStart():Promise; + +export function WslStop():Promise; diff --git a/frontend/wailsjs/go/backend_golang/App.js b/frontend/wailsjs/go/backend_golang/App.js index ea9ef66..d1c629e 100755 --- a/frontend/wailsjs/go/backend_golang/App.js +++ b/frontend/wailsjs/go/backend_golang/App.js @@ -10,6 +10,10 @@ export function ContinueDownload(arg1) { return window['go']['backend_golang']['App']['ContinueDownload'](arg1); } +export function ConvertData(arg1, arg2, arg3, arg4) { + return window['go']['backend_golang']['App']['ConvertData'](arg1, arg2, arg3, arg4); +} + export function ConvertModel(arg1, arg2, arg3, arg4) { return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4); } @@ -46,6 +50,10 @@ export function ListDirFiles(arg1) { return window['go']['backend_golang']['App']['ListDirFiles'](arg1); } +export function MergeLora(arg1, arg2, arg3, arg4, arg5, arg6) { + return window['go']['backend_golang']['App']['MergeLora'](arg1, arg2, arg3, arg4, arg5, arg6); +} + export function OpenFileFolder(arg1, arg2) { return window['go']['backend_golang']['App']['OpenFileFolder'](arg1, arg2); } @@ -81,3 +89,27 @@ export function StartServer(arg1, arg2, arg3) { export function UpdateApp(arg1) { return window['go']['backend_golang']['App']['UpdateApp'](arg1); } + +export function WslCommand(arg1) { + return window['go']['backend_golang']['App']['WslCommand'](arg1); +} + +export function WslEnable(arg1) { + return window['go']['backend_golang']['App']['WslEnable'](arg1); +} + +export function WslInstallUbuntu() { + return window['go']['backend_golang']['App']['WslInstallUbuntu'](); +} + +export function WslIsEnabled() { + return window['go']['backend_golang']['App']['WslIsEnabled'](); +} + +export function WslStart() { + return window['go']['backend_golang']['App']['WslStart'](); +} + +export function WslStop() { + return window['go']['backend_golang']['App']['WslStop'](); +} diff --git a/go.mod b/go.mod index 958411f..300ae2d 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,14 @@ go 1.20 require ( github.com/cavaliergopher/grab/v3 v3.0.1 github.com/minio/selfupdate v0.6.0 + github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e github.com/wailsapp/wails/v2 v2.5.1 ) require ( aead.dev/minisign v0.2.0 // indirect github.com/bep/debounce v1.2.1 // indirect + github.com/fsnotify/fsnotify v1.6.0 github.com/go-ole/go-ole v1.2.6 // indirect github.com/google/uuid v1.3.0 // indirect github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect @@ -21,17 +23,20 @@ require ( github.com/leaanthony/slicer v1.6.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.18 // indirect + github.com/nyaosorg/go-windows-su v0.2.1 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/samber/lo v1.38.1 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect github.com/tkrajina/go-reflector v0.5.6 // indirect + github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 
// indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect github.com/wailsapp/mimetype v1.4.1 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect + golang.org/x/sys v0.9.0 // indirect golang.org/x/text v0.9.0 // indirect ) diff --git a/go.sum b/go.sum index ab625bc..cd6bb56 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ aead.dev/minisign v0.2.0 h1:kAWrq/hBRu4AARY6AlciO83xhNnW9UaC8YipS2uhLPk= aead.dev/minisign v0.2.0/go.mod h1:zdq6LdSd9TbuSxchxwhpA9zEb9YXcVGoE8JakuiGaIQ= +github.com/0xrawsec/golang-utils v1.3.2 h1:ww4jrtHRSnX9xrGzJYbalx5nXoZewy4zPxiY+ubJgtg= github.com/bep/debounce v1.2.1 h1:v67fRdBA9UQu2NhLFXrSg0Brw7CexQekrBwDMM8bzeY= github.com/bep/debounce v1.2.1/go.mod h1:H8yggRPQKLUhUoqrJC1bO2xNya7vanpDl7xR3ISbCJ0= github.com/cavaliergopher/grab/v3 v3.0.1 h1:4z7TkBfmPjmLAAmkkAZNX/6QJ1nNFdv3SdIHXju0Fr4= @@ -7,6 +8,8 @@ github.com/cavaliergopher/grab/v3 v3.0.1/go.mod h1:1U/KNnD+Ft6JJiYoYBAimKH2XrYpt github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -37,6 +40,8 @@ github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9 github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU= github.com/minio/selfupdate v0.6.0/go.mod h1:bO02GTIPCMQFTEvE5h4DjYB58bCoZ35XLeBf0buTDdM= +github.com/nyaosorg/go-windows-su v0.2.1 h1:5V0XavLyjOqPUp7psxxCvBISaneU4XmFPSMlejSl5sc= +github.com/nyaosorg/go-windows-su v0.2.1/go.mod h1:fWKxSCXwGuDuW6ne0kLp/Cj0joXNDDw01G3LseQJYS0= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -48,11 +53,17 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tkrajina/go-reflector v0.5.6 
h1:hKQ0gyocG7vgMD2M3dRlYN6WBBOmdoOzJ6njQSepKdE= github.com/tkrajina/go-reflector v0.5.6/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4= +github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 h1:XQpsQG5lqRJlx4mUVHcJvyyc1rdTI9nHvwrdfcuy8aM= +github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117/go.mod h1:mx0TjbqsaDD9DUT5gA1s3hw47U6RIbbIBfvGzR85K0g= +github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e h1:5hJ4Z9ISvbDUWL7TDvfoYp0bXsaX42WjAUJzyZ8NMCI= +github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e/go.mod h1:tu2rOgQGt6bZce1OE8G75Ca8+NvNmTNOvplLolr326I= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= @@ -86,10 +97,12 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/main.go b/main.go index d237709..abdc8bd 100644 --- a/main.go +++ b/main.go @@ -26,12 +26,22 @@ var cyacInfo embed.FS //go:embed backend-python var py embed.FS +//go:embed finetune +var finetune embed.FS + func main() { if buildInfo, ok := debug.ReadBuildInfo(); !ok || strings.Contains(buildInfo.String(), "-ldflags") { backend.CopyEmbed(cyac) backend.CopyEmbed(cyacInfo) backend.CopyEmbed(py) + backend.CopyEmbed(finetune) os.Mkdir("models", os.ModePerm) + os.Mkdir("lora-models", os.ModePerm) + } + + f, err := os.Create("lora-models/train_log.txt") + if err == nil { + f.Close() } // Create an instance of the app structure
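As a point of reference, here is a minimal, hypothetical sketch of how the frontend could drive the two new conversion/merge bindings declared in App.d.ts above. The argument meanings and ordering (Python path, data/vocab paths, LoRA alpha, model paths) are inferred from the surrounding TypeScript and are assumptions rather than a confirmed contract; the output prefix mirrors the converted-data path checked in StartLoraFinetune.

import { ConvertData, MergeLora } from '../../wailsjs/go/backend_golang/App';
import commonStore from '../stores/commonStore';

// Assumed argument order: (pythonPath, inputPath, outputPrefix, vocabPath); an empty pythonPath
// is assumed to let the backend resolve its own Python interpreter.
async function convertTrainingData(): Promise<void> {
  const dataPath = commonStore.dataProcessParams.dataPath;
  const outputPrefix = `./finetune/json2binidx_tool/data/${dataPath.split('/').pop()!.split('.')[0]}`;
  await ConvertData('', dataPath, outputPrefix, commonStore.dataProcessParams.vocabPath);
}

// Assumed argument order: (pythonPath, useGpu, loraAlpha, baseModelPath, loraCheckpointPath, outputPath).
async function mergeLoraIntoBase(loraName: string, outputPath: string): Promise<void> {
  const p = commonStore.loraFinetuneParams;
  await MergeLora('', true, p.loraAlpha, `models/${p.baseModel}`, `lora-models/${loraName}`, outputPath);
}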