From f05a4acb04b7d2f7bbeb5c8f3fe2f08ce8758d1d Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 28 May 2024 22:35:47 +0800 Subject: [PATCH] sync https://github.com/JL-er/RWKV-PEFT --- finetune/lora/v6/cuda/wkv6infctx_cuda.cu | 311 ++++ finetune/lora/v6/cuda/wkv6infctx_op.cpp | 22 + finetune/lora/v6/cuda/wkv6state_cuda.cu | 311 ++++ finetune/lora/v6/cuda/wkv6state_op.cpp | 22 + finetune/lora/v6/demo/demo-lora-merge.sh | 16 + finetune/lora/v6/demo/demo-lora.sh | 27 + finetune/lora/v6/demo/demo-pissa-merge.sh | 15 + finetune/lora/v6/demo/demo-pissa.sh | 40 + finetune/lora/v6/demo/demo-qpissa-pt.sh | 27 + finetune/lora/v6/demo/demo-state-merge.sh | 8 + finetune/lora/v6/demo/demo-state-tuning.sh | 22 + .../lora/v6/demo/demo-training-prepare.sh | 27 + finetune/lora/v6/demo/demo-training-run.sh | 21 + finetune/lora/v6/demo/demo.jsonl | 182 ++ finetune/lora/v6/demo/infctx.sh | 25 + finetune/lora/v6/fla/__init__.py | 50 + finetune/lora/v6/fla/layers/__init__.py | 25 + finetune/lora/v6/fla/layers/abc.py | 195 +++ finetune/lora/v6/fla/layers/based.py | 126 ++ finetune/lora/v6/fla/layers/delta_net.py | 254 +++ finetune/lora/v6/fla/layers/gated_abc.py | 234 +++ finetune/lora/v6/fla/layers/gla.py | 268 +++ finetune/lora/v6/fla/layers/hgrn.py | 165 ++ finetune/lora/v6/fla/layers/hgrn2.py | 186 ++ finetune/lora/v6/fla/layers/linear_attn.py | 156 ++ .../v6/fla/layers/multiscale_retention.py | 271 +++ finetune/lora/v6/fla/layers/rebased.py | 137 ++ finetune/lora/v6/fla/layers/rwkv6.py | 264 +++ finetune/lora/v6/fla/layers/simple_gla.py | 143 ++ finetune/lora/v6/fla/models/__init__.py | 29 + finetune/lora/v6/fla/models/abc/__init__.py | 13 + .../v6/fla/models/abc/configuration_abc.py | 74 + .../lora/v6/fla/models/abc/modeling_abc.py | 394 +++++ .../lora/v6/fla/models/delta_net/__init__.py | 14 + .../delta_net/configuration_delta_net.py | 77 + .../models/delta_net/modeling_delta_net.py | 405 +++++ finetune/lora/v6/fla/models/gla/__init__.py | 13 + .../v6/fla/models/gla/configuration_gla.py | 80 + .../lora/v6/fla/models/gla/modeling_gla.py | 403 +++++ finetune/lora/v6/fla/models/hgrn/__init__.py | 13 + .../v6/fla/models/hgrn/configuration_hgrn.py | 66 + .../lora/v6/fla/models/hgrn/modeling_hgrn.py | 407 +++++ finetune/lora/v6/fla/models/hgrn2/__init__.py | 13 + .../fla/models/hgrn2/configuration_hgrn2.py | 66 + .../v6/fla/models/hgrn2/modeling_hgrn2.py | 407 +++++ .../v6/fla/models/linear_attn/__init__.py | 14 + .../linear_attn/configuration_linear_attn.py | 70 + .../linear_attn/modeling_linear_attn.py | 424 +++++ finetune/lora/v6/fla/models/mamba/__init__.py | 14 + .../fla/models/mamba/configuration_mamba.py | 156 ++ .../v6/fla/models/mamba/modeling_mamba.py | 605 +++++++ .../lora/v6/fla/models/retnet/__init__.py | 13 + .../fla/models/retnet/configuration_retnet.py | 76 + .../v6/fla/models/retnet/modeling_retnet.py | 410 +++++ finetune/lora/v6/fla/models/rwkv6/__init__.py | 13 + .../fla/models/rwkv6/configuration_rwkv6.py | 66 + .../v6/fla/models/rwkv6/modeling_rwkv6.py | 443 +++++ .../v6/fla/models/transformer/__init__.py | 14 + .../transformer/configuration_transformer.py | 61 + .../transformer/modeling_transformer.py | 522 ++++++ finetune/lora/v6/fla/models/utils.py | 107 ++ finetune/lora/v6/fla/modules/__init__.py | 20 + finetune/lora/v6/fla/modules/activations.py | 394 +++++ finetune/lora/v6/fla/modules/convolution.py | 336 ++++ finetune/lora/v6/fla/modules/feature_map.py | 235 +++ .../v6/fla/modules/fused_cross_entropy.py | 398 +++++ .../lora/v6/fla/modules/fused_norm_gate.py | 889 ++++++++++ finetune/lora/v6/fla/modules/l2norm.py | 216 +++ finetune/lora/v6/fla/modules/layernorm.py | 802 +++++++++ finetune/lora/v6/fla/modules/rotary.py | 310 ++++ finetune/lora/v6/fla/ops/__init__.py | 18 + finetune/lora/v6/fla/ops/abc/__init__.py | 11 + finetune/lora/v6/fla/ops/abc/chunk.py | 1194 +++++++++++++ finetune/lora/v6/fla/ops/abc/chunk_gate.py | 1287 ++++++++++++++ finetune/lora/v6/fla/ops/abc/naive.py | 90 + .../lora/v6/fla/ops/abc/recurrent_fuse.py | 388 +++++ finetune/lora/v6/fla/ops/based/__init__.py | 9 + finetune/lora/v6/fla/ops/based/chunk_fuse.py | 410 +++++ finetune/lora/v6/fla/ops/based/naive.py | 132 ++ finetune/lora/v6/fla/ops/based/parallel.py | 388 +++++ finetune/lora/v6/fla/ops/delta_rule/README.md | 4 + .../lora/v6/fla/ops/delta_rule/__init__.py | 11 + finetune/lora/v6/fla/ops/delta_rule/chunk.py | 544 ++++++ .../lora/v6/fla/ops/delta_rule/chunk_fuse.py | 419 +++++ finetune/lora/v6/fla/ops/delta_rule/naive.py | 92 + .../v6/fla/ops/delta_rule/recurrent_fuse.py | 312 ++++ finetune/lora/v6/fla/ops/delta_rule/utils.py | 297 ++++ .../lora/v6/fla/ops/delta_rule/wy_fast.py | 401 +++++ finetune/lora/v6/fla/ops/gla/__init__.py | 11 + finetune/lora/v6/fla/ops/gla/chunk.py | 734 ++++++++ finetune/lora/v6/fla/ops/gla/chunk_fuse.py | 548 ++++++ finetune/lora/v6/fla/ops/gla/chunk_util.py | 138 ++ finetune/lora/v6/fla/ops/gla/naive.py | 116 ++ .../lora/v6/fla/ops/gla/recurrent_fuse.py | 404 +++++ finetune/lora/v6/fla/ops/hgrn/__init__.py | 9 + finetune/lora/v6/fla/ops/hgrn/chunk.py | 373 ++++ finetune/lora/v6/fla/ops/hgrn/naive.py | 31 + .../lora/v6/fla/ops/hgrn/recurrent_fuse.py | 185 ++ .../lora/v6/fla/ops/linear_attn/__init__.py | 12 + finetune/lora/v6/fla/ops/linear_attn/chunk.py | 359 ++++ .../lora/v6/fla/ops/linear_attn/chunk_fuse.py | 326 ++++ finetune/lora/v6/fla/ops/linear_attn/naive.py | 20 + .../v6/fla/ops/linear_attn/recurrent_fuse.py | 284 ++++ finetune/lora/v6/fla/ops/rebased/__init__.py | 7 + finetune/lora/v6/fla/ops/rebased/naive.py | 80 + finetune/lora/v6/fla/ops/rebased/parallel.py | 387 +++++ .../lora/v6/fla/ops/retention/__init__.py | 13 + finetune/lora/v6/fla/ops/retention/chunk.py | 364 ++++ .../lora/v6/fla/ops/retention/chunk_fuse.py | 334 ++++ finetune/lora/v6/fla/ops/retention/naive.py | 15 + .../lora/v6/fla/ops/retention/parallel.py | 339 ++++ .../v6/fla/ops/retention/recurrent_fuse.py | 281 +++ finetune/lora/v6/fla/ops/rotary.py | 252 +++ finetune/lora/v6/fla/ops/rwkv4/__init__.py | 7 + .../lora/v6/fla/ops/rwkv4/recurrent_fuse.py | 484 ++++++ finetune/lora/v6/fla/ops/rwkv6/__init__.py | 9 + finetune/lora/v6/fla/ops/rwkv6/chunk.py | 921 ++++++++++ finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py | 79 + .../lora/v6/fla/ops/rwkv6/recurrent_fuse.py | 378 +++++ .../lora/v6/fla/ops/rwkv6/recurrent_naive.py | 102 ++ finetune/lora/v6/fla/ops/simple_gla/README.md | 5 + .../lora/v6/fla/ops/simple_gla/__init__.py | 8 + finetune/lora/v6/fla/ops/simple_gla/chunk.py | 415 +++++ finetune/lora/v6/fla/ops/simple_gla/naive.py | 52 + finetune/lora/v6/fla/ops/utils.py | 579 +++++++ finetune/lora/v6/fla/utils.py | 33 + finetune/lora/v6/merge/merge.py | 92 + finetune/lora/v6/merge/merge_lora.py | 52 + finetune/lora/v6/merge/merge_pissa.py | 58 + finetune/lora/v6/merge/merge_state.py | 36 + finetune/lora/v6/requirements.txt | 5 + finetune/lora/v6/src/binidx.py | 4 +- finetune/lora/v6/src/dataset.py | 8 +- finetune/lora/v6/src/infctx_module.py | 52 + finetune/lora/v6/src/model.py | 1508 +++++++++++++---- finetune/lora/v6/src/trainer.py | 43 +- finetune/lora/v6/train.py | 184 +- finetune/requirements.txt | 6 +- 138 files changed, 29047 insertions(+), 334 deletions(-) create mode 100644 finetune/lora/v6/cuda/wkv6infctx_cuda.cu create mode 100644 finetune/lora/v6/cuda/wkv6infctx_op.cpp create mode 100644 finetune/lora/v6/cuda/wkv6state_cuda.cu create mode 100644 finetune/lora/v6/cuda/wkv6state_op.cpp create mode 100644 finetune/lora/v6/demo/demo-lora-merge.sh create mode 100644 finetune/lora/v6/demo/demo-lora.sh create mode 100644 finetune/lora/v6/demo/demo-pissa-merge.sh create mode 100644 finetune/lora/v6/demo/demo-pissa.sh create mode 100644 finetune/lora/v6/demo/demo-qpissa-pt.sh create mode 100644 finetune/lora/v6/demo/demo-state-merge.sh create mode 100644 finetune/lora/v6/demo/demo-state-tuning.sh create mode 100644 finetune/lora/v6/demo/demo-training-prepare.sh create mode 100644 finetune/lora/v6/demo/demo-training-run.sh create mode 100644 finetune/lora/v6/demo/demo.jsonl create mode 100644 finetune/lora/v6/demo/infctx.sh create mode 100644 finetune/lora/v6/fla/__init__.py create mode 100644 finetune/lora/v6/fla/layers/__init__.py create mode 100644 finetune/lora/v6/fla/layers/abc.py create mode 100644 finetune/lora/v6/fla/layers/based.py create mode 100644 finetune/lora/v6/fla/layers/delta_net.py create mode 100644 finetune/lora/v6/fla/layers/gated_abc.py create mode 100644 finetune/lora/v6/fla/layers/gla.py create mode 100644 finetune/lora/v6/fla/layers/hgrn.py create mode 100644 finetune/lora/v6/fla/layers/hgrn2.py create mode 100644 finetune/lora/v6/fla/layers/linear_attn.py create mode 100644 finetune/lora/v6/fla/layers/multiscale_retention.py create mode 100644 finetune/lora/v6/fla/layers/rebased.py create mode 100644 finetune/lora/v6/fla/layers/rwkv6.py create mode 100644 finetune/lora/v6/fla/layers/simple_gla.py create mode 100644 finetune/lora/v6/fla/models/__init__.py create mode 100644 finetune/lora/v6/fla/models/abc/__init__.py create mode 100644 finetune/lora/v6/fla/models/abc/configuration_abc.py create mode 100644 finetune/lora/v6/fla/models/abc/modeling_abc.py create mode 100644 finetune/lora/v6/fla/models/delta_net/__init__.py create mode 100644 finetune/lora/v6/fla/models/delta_net/configuration_delta_net.py create mode 100644 finetune/lora/v6/fla/models/delta_net/modeling_delta_net.py create mode 100644 finetune/lora/v6/fla/models/gla/__init__.py create mode 100644 finetune/lora/v6/fla/models/gla/configuration_gla.py create mode 100644 finetune/lora/v6/fla/models/gla/modeling_gla.py create mode 100644 finetune/lora/v6/fla/models/hgrn/__init__.py create mode 100644 finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py create mode 100644 finetune/lora/v6/fla/models/hgrn/modeling_hgrn.py create mode 100644 finetune/lora/v6/fla/models/hgrn2/__init__.py create mode 100644 finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py create mode 100644 finetune/lora/v6/fla/models/hgrn2/modeling_hgrn2.py create mode 100644 finetune/lora/v6/fla/models/linear_attn/__init__.py create mode 100644 finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py create mode 100644 finetune/lora/v6/fla/models/linear_attn/modeling_linear_attn.py create mode 100644 finetune/lora/v6/fla/models/mamba/__init__.py create mode 100644 finetune/lora/v6/fla/models/mamba/configuration_mamba.py create mode 100644 finetune/lora/v6/fla/models/mamba/modeling_mamba.py create mode 100644 finetune/lora/v6/fla/models/retnet/__init__.py create mode 100644 finetune/lora/v6/fla/models/retnet/configuration_retnet.py create mode 100644 finetune/lora/v6/fla/models/retnet/modeling_retnet.py create mode 100644 finetune/lora/v6/fla/models/rwkv6/__init__.py create mode 100644 finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py create mode 100644 finetune/lora/v6/fla/models/rwkv6/modeling_rwkv6.py create mode 100644 finetune/lora/v6/fla/models/transformer/__init__.py create mode 100644 finetune/lora/v6/fla/models/transformer/configuration_transformer.py create mode 100644 finetune/lora/v6/fla/models/transformer/modeling_transformer.py create mode 100644 finetune/lora/v6/fla/models/utils.py create mode 100644 finetune/lora/v6/fla/modules/__init__.py create mode 100644 finetune/lora/v6/fla/modules/activations.py create mode 100644 finetune/lora/v6/fla/modules/convolution.py create mode 100644 finetune/lora/v6/fla/modules/feature_map.py create mode 100644 finetune/lora/v6/fla/modules/fused_cross_entropy.py create mode 100644 finetune/lora/v6/fla/modules/fused_norm_gate.py create mode 100644 finetune/lora/v6/fla/modules/l2norm.py create mode 100644 finetune/lora/v6/fla/modules/layernorm.py create mode 100644 finetune/lora/v6/fla/modules/rotary.py create mode 100644 finetune/lora/v6/fla/ops/__init__.py create mode 100644 finetune/lora/v6/fla/ops/abc/__init__.py create mode 100644 finetune/lora/v6/fla/ops/abc/chunk.py create mode 100644 finetune/lora/v6/fla/ops/abc/chunk_gate.py create mode 100644 finetune/lora/v6/fla/ops/abc/naive.py create mode 100644 finetune/lora/v6/fla/ops/abc/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/based/__init__.py create mode 100644 finetune/lora/v6/fla/ops/based/chunk_fuse.py create mode 100644 finetune/lora/v6/fla/ops/based/naive.py create mode 100644 finetune/lora/v6/fla/ops/based/parallel.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/README.md create mode 100644 finetune/lora/v6/fla/ops/delta_rule/__init__.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/chunk.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/naive.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/utils.py create mode 100644 finetune/lora/v6/fla/ops/delta_rule/wy_fast.py create mode 100644 finetune/lora/v6/fla/ops/gla/__init__.py create mode 100644 finetune/lora/v6/fla/ops/gla/chunk.py create mode 100644 finetune/lora/v6/fla/ops/gla/chunk_fuse.py create mode 100644 finetune/lora/v6/fla/ops/gla/chunk_util.py create mode 100644 finetune/lora/v6/fla/ops/gla/naive.py create mode 100644 finetune/lora/v6/fla/ops/gla/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/hgrn/__init__.py create mode 100644 finetune/lora/v6/fla/ops/hgrn/chunk.py create mode 100644 finetune/lora/v6/fla/ops/hgrn/naive.py create mode 100644 finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/linear_attn/__init__.py create mode 100644 finetune/lora/v6/fla/ops/linear_attn/chunk.py create mode 100644 finetune/lora/v6/fla/ops/linear_attn/chunk_fuse.py create mode 100644 finetune/lora/v6/fla/ops/linear_attn/naive.py create mode 100644 finetune/lora/v6/fla/ops/linear_attn/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/rebased/__init__.py create mode 100644 finetune/lora/v6/fla/ops/rebased/naive.py create mode 100644 finetune/lora/v6/fla/ops/rebased/parallel.py create mode 100644 finetune/lora/v6/fla/ops/retention/__init__.py create mode 100644 finetune/lora/v6/fla/ops/retention/chunk.py create mode 100644 finetune/lora/v6/fla/ops/retention/chunk_fuse.py create mode 100644 finetune/lora/v6/fla/ops/retention/naive.py create mode 100644 finetune/lora/v6/fla/ops/retention/parallel.py create mode 100644 finetune/lora/v6/fla/ops/retention/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/rotary.py create mode 100644 finetune/lora/v6/fla/ops/rwkv4/__init__.py create mode 100644 finetune/lora/v6/fla/ops/rwkv4/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/rwkv6/__init__.py create mode 100644 finetune/lora/v6/fla/ops/rwkv6/chunk.py create mode 100644 finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py create mode 100644 finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py create mode 100644 finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py create mode 100644 finetune/lora/v6/fla/ops/simple_gla/README.md create mode 100644 finetune/lora/v6/fla/ops/simple_gla/__init__.py create mode 100644 finetune/lora/v6/fla/ops/simple_gla/chunk.py create mode 100644 finetune/lora/v6/fla/ops/simple_gla/naive.py create mode 100644 finetune/lora/v6/fla/ops/utils.py create mode 100644 finetune/lora/v6/fla/utils.py create mode 100644 finetune/lora/v6/merge/merge.py create mode 100644 finetune/lora/v6/merge/merge_lora.py create mode 100644 finetune/lora/v6/merge/merge_pissa.py create mode 100644 finetune/lora/v6/merge/merge_state.py create mode 100644 finetune/lora/v6/requirements.txt create mode 100644 finetune/lora/v6/src/infctx_module.py diff --git a/finetune/lora/v6/cuda/wkv6infctx_cuda.cu b/finetune/lora/v6/cuda/wkv6infctx_cuda.cu new file mode 100644 index 0000000..a3a19da --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6infctx_cuda.cu @@ -0,0 +1,311 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, F *__restrict__ _s, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j]); + } + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = __expf(-__expf(float(_w[t]))); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } + #pragma unroll + for (int j = 0; j < _N_; j++) + _s[j] = F(state[j]); +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_]; + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j*_N_]); + swwww[j] = 1.0; + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = __expf(-__expf(float(_w[t]))); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = __expf(-__expf(float(_w[t]))); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = __expf(-__expf(float(_w[t]))); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } + + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + w_[i] = __expf(-__expf(float(_w[t]))); + __syncthreads(); + + const float gyy = float(_gy[t]); + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& w = swwww[j]; + sssss[j] += gyy * w * r[j]; + w *= w_[j]; + } + } + for (int j = 0; j < _N_; j++) + _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]); +} + +template +__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _s += h*_N_*_N_ + i; + + __shared__ float v[_N_], gy[_N_]; + float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0}; + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j*_N_]); + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = __expf(-__expf(float(_w[t-C]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * v[j]; + } + sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]); + } + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float r = float(_r[t_1]); + const float w = __expf(-__expf(float(_w[t_0]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * state[j]; + } + sbbbb[0] = sum; + } + + float sss = sbbbb[0]; + _gw[t_0] = F(sss * -__expf(float(_w[t_0]))); + + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float w = __expf(-__expf(float(_w[t_0]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + state[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[1] - (sum * float(_r[t_1])); + _gw[t_1] = F(sss * -__expf(float(_w[t_1]))); + } + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = __expf(-__expf(float(_w[t-C]))); + const float k = float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + k * v[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t])); + _gw[t] = F(sss * -__expf(float(_w[t]))); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, z, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs); + kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, z, gy, gw); +} diff --git a/finetune/lora/v6/cuda/wkv6infctx_op.cpp b/finetune/lora/v6/cuda/wkv6infctx_op.cpp new file mode 100644 index 0000000..4df24ca --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6infctx_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6state forward"); + m.def("backward", &backward, "wkv6state backward"); +} + +TORCH_LIBRARY(wkv6state, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/v6/cuda/wkv6state_cuda.cu b/finetune/lora/v6/cuda/wkv6state_cuda.cu new file mode 100644 index 0000000..3c28c2b --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6state_cuda.cu @@ -0,0 +1,311 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u,const F *__restrict__ _s, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i*_N_; + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + float state[_N_]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j]); + } + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = __expf(-__expf(float(_w[t]))); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } + // #pragma unroll + // for (int j = 0; j < _N_; j++) + // _s[j] = F(state[j]); +} + +template +__global__ void kernel_backward_111(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu, F *__restrict__ const _gs) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _s += h*_N_*_N_ + i; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_], scccc[_N_] = {0}, sdddd[_N_] = {0}, sssss[_N_] = {0}, swwww[_N_]; + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j*_N_]); + swwww[j] = 1.0; + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_T_1 = t_0 + (T-1)*C; + const int t_T = t_0 + T*C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = __expf(-__expf(float(_w[t]))); + float gr = 0, gu_ = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b*C + h*_N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = __expf(-__expf(float(_w[t]))); + float gk = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = __expf(-__expf(float(_w[t]))); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } + + for (int t = t_0; t < t_T; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + w_[i] = __expf(-__expf(float(_w[t]))); + __syncthreads(); + + const float gyy = float(_gy[t]); + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& w = swwww[j]; + sssss[j] += gyy * w * r[j]; + w *= w_[j]; + } + } + for (int j = 0; j < _N_; j++) + _gs[b*H*_N_*_N_ + h*_N_*_N_ + i*_N_ + j] = F(sssss[j]); +} + +template +__global__ void kernel_backward_222(const int B, const int T, const int C, const int H, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ _s, const F *__restrict__ const _gy, + F *__restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _s += h*_N_*_N_ + i; + + __shared__ float v[_N_], gy[_N_]; + float state[_N_], saaaa[_N_] = {0}, sbbbb[_T_-1] = {0}, scccc[_N_] = {0}; + for (int j = 0; j < _N_; j++) { + state[j] = float(_s[j*_N_]); + } + + const int t_0 = b*T*C + h*_N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2*C; + const int t_T_1 = t_0 + (T-1)*C; + + for (int t = t_T_1; t > t_1; t -= C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = __expf(-__expf(float(_w[t-C]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * v[j]; + } + sbbbb[(t-t_1)/C] = sum * float(_k[t-2*C]); + } + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float r = float(_r[t_1]); + const float w = __expf(-__expf(float(_w[t_0]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = saaaa[j]; + s = (s + r * gy[j]) * w; + sum += s * state[j]; + } + sbbbb[0] = sum; + } + + float sss = sbbbb[0]; + _gw[t_0] = F(sss * -__expf(float(_w[t_0]))); + + { + __syncthreads(); + gy[i] = float(_gy[t_1]); + __syncthreads(); + + const float w = __expf(-__expf(float(_w[t_0]))); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + state[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[1] - (sum * float(_r[t_1])); + _gw[t_1] = F(sss * -__expf(float(_w[t_1]))); + } + for (int t = t_2; t < t_T_1; t += C) + { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t-2*C]); + __syncthreads(); + + const float w = __expf(-__expf(float(_w[t-C]))); + const float k = float(_k[t-2*C]); + float sum = 0.0f; + + #pragma unroll + for (int j = 0; j < _N_; j++) + { + float& s = scccc[j]; + s = (s + k * v[j]) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t-t_0)/C] - (sum * float(_r[t])); + _gw[t] = F(sss * -__expf(float(_w[t]))); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *y) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, z, y); +} + +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *z, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs) +{ + assert(H*_N_ == C); + assert(_N_%4 == 0); + kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, z, gy, gr, gk, gv, gu, gs); + kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, z, gy, gw); +} diff --git a/finetune/lora/v6/cuda/wkv6state_op.cpp b/finetune/lora/v6/cuda/wkv6state_op.cpp new file mode 100644 index 0000000..4df24ca --- /dev/null +++ b/finetune/lora/v6/cuda/wkv6state_op.cpp @@ -0,0 +1,22 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); +void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); + +void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { + cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { + cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv6state forward"); + m.def("backward", &backward, "wkv6state backward"); +} + +TORCH_LIBRARY(wkv6state, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/finetune/lora/v6/demo/demo-lora-merge.sh b/finetune/lora/v6/demo/demo-lora-merge.sh new file mode 100644 index 0000000..5d97422 --- /dev/null +++ b/finetune/lora/v6/demo/demo-lora-merge.sh @@ -0,0 +1,16 @@ + +base_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' +lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth' +lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth' +output='/home/rwkv/JL/model/nf4-world.pth' +QUANT='nf4' #follow train +TYPE='lora' +Lora_alpha=128 + +python merge/merge.py --base_model $base_model \ +--lora_init $lora_init \ +--lora_checkpoint $lora_checkpoint \ +--output $output \ +--quant $QUANT \ +--type $TYPE \ +--lora_alpha $Lora_alpha \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-lora.sh b/finetune/lora/v6/demo/demo-lora.sh new file mode 100644 index 0000000..f244cc3 --- /dev/null +++ b/finetune/lora/v6/demo/demo-lora.sh @@ -0,0 +1,27 @@ +load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' +proj_dir='/home/rwkv/JL/out_model/nf4' +data_file='/home/rwkv/JL/data/roleplay' + +QUANT='nf4' #4bit nf4 fp4 none + +lora_r=64 +lora_alpha=128 + +n_layer=32 +n_embd=4096 + +micro_bsz=8 +epoch_save=1 +epoch_steps=1000 +ctx_len=1024 + +python train.py --load_model $load_model \ +--proj_dir $proj_dir --data_file $data_file \ +--data_type binidx --vocab_size 65536 \ +--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +--n_layer $n_layer --n_embd $n_embd \ +--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ +--my_testing "x060" \ +--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha $lora_alpha --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ +--quant $QUANT \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-pissa-merge.sh b/finetune/lora/v6/demo/demo-pissa-merge.sh new file mode 100644 index 0000000..73645ee --- /dev/null +++ b/finetune/lora/v6/demo/demo-pissa-merge.sh @@ -0,0 +1,15 @@ + + +base_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2-20240208-ctx4096.pth' +lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth' +lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth' +output='/home/rwkv/JL/model/end-world.pth' +QUANT='nf4' #follow train +TYPE='pissa' + +python merge/merge.py --base_model $base_model \ +--lora_init $lora_init \ +--lora_checkpoint $lora_checkpoint \ +--output $output \ +--quant $QUANT \ +--type $TYPE \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-pissa.sh b/finetune/lora/v6/demo/demo-pissa.sh new file mode 100644 index 0000000..fc5d183 --- /dev/null +++ b/finetune/lora/v6/demo/demo-pissa.sh @@ -0,0 +1,40 @@ + +load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' +proj_dir='/home/rwkv/JL/out_model/nf4' +data_file='/home/rwkv/JL/data/end_text_document' + +QUANT='nf4' #4bit nf4 fp4 none +svd_niter=4 +lora_r=64 + +n_layer=24 +n_embd=2048 + +micro_bsz=8 +epoch_save=1 +epoch_steps=1000 +ctx_len=1024 + +python train.py --load_model $load_model \ +--proj_dir $proj_dir --data_file $data_file \ +--data_type binidx --vocab_size 65536 \ +--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +--n_layer $n_layer --n_embd $n_embd \ +--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ +--my_testing "x060" \ +--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ +--PISSA --svd_niter $svd_niter \ +--dataload pad + +###remove load_model +# python train.py --proj_dir $proj_dir --data_file $data_file \ +# --data_type binidx --vocab_size 65536 \ +# --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +# --n_layer $n_layer --n_embd $n_embd \ +# --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +# --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ +# --my_testing "x060" \ +# --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ +# --PISSA --svd_niter $svd_niter \ +# --quant $QUANT \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-qpissa-pt.sh b/finetune/lora/v6/demo/demo-qpissa-pt.sh new file mode 100644 index 0000000..fd4b952 --- /dev/null +++ b/finetune/lora/v6/demo/demo-qpissa-pt.sh @@ -0,0 +1,27 @@ +load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' +proj_dir='/home/rwkv/JL/out_model/nf4' +data_file='/home/rwkv/JL/data/roleplay' + +QUANT='nf4' #4bit nf4 fp4 none +svd_niter=4 +lora_r=64 + +n_layer=32 +n_embd=4096 + +micro_bsz=4 +epoch_save=1 +epoch_steps=1000 +ctx_len=1024 + + +python train.py --proj_dir $proj_dir --data_file $data_file \ +--data_type binidx --vocab_size 65536 \ +--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +--n_layer $n_layer --n_embd $n_embd \ +--pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ +--my_testing "x060" \ +--lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ +--PISSA --svd_niter $svd_niter \ +--quant $QUANT \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-state-merge.sh b/finetune/lora/v6/demo/demo-state-merge.sh new file mode 100644 index 0000000..7a79110 --- /dev/null +++ b/finetune/lora/v6/demo/demo-state-merge.sh @@ -0,0 +1,8 @@ +base_model='/home/rwkv/JL/model/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth' +state_checkpoint='/home/rwkv/JL/out_model/state/rwkv-9.pth' +output='/home/rwkv/JL/model/state-0.pth' + + +python merge/merge_state.py --base_model $base_model \ +--state_checkpoint $state_checkpoint \ +--output $output \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-state-tuning.sh b/finetune/lora/v6/demo/demo-state-tuning.sh new file mode 100644 index 0000000..238a0cd --- /dev/null +++ b/finetune/lora/v6/demo/demo-state-tuning.sh @@ -0,0 +1,22 @@ +load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' +proj_dir='/home/rwkv/JL/out_model/state' +data_file='/home/rwkv/JL/data/end_text_document' + + +n_layer=24 +n_embd=2048 + +micro_bsz=1 +epoch_save=1 +epoch_steps=1000 +ctx_len=1024 + +python train.py --load_model $load_model \ +--proj_dir $proj_dir --data_file $data_file \ +--data_type binidx --vocab_size 65536 \ +--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +--n_layer $n_layer --n_embd $n_embd \ +--pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-1 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 0 \ +--my_testing "x060" \ +--train_type "state" --dataload pad --wandb fla --fla \ No newline at end of file diff --git a/finetune/lora/v6/demo/demo-training-prepare.sh b/finetune/lora/v6/demo/demo-training-prepare.sh new file mode 100644 index 0000000..5615dac --- /dev/null +++ b/finetune/lora/v6/demo/demo-training-prepare.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Create data directory + +mkdir -p data + +# Download minipile (1498226207 tokens, around 3GB) + +wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx +wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin + +# Generate initial model (L12-D768 = 169M) + +BASE_NAME="model/0.1-1" +N_LAYER="12" +N_EMBD="768" + +# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case) +# use https://www.dcode.fr/prime-numbers-search + +python train.py --wandb "" --proj_dir $BASE_NAME \ + --data_file "data/minipile" --data_type "binidx" --vocab_size 65536 \ + --ctx_len 512 --my_pile_stage 1 --epoch_count 1 --epoch_begin 0 \ + --epoch_save 1 --weight_decay 0 --head_size_a 64 \ + --num_nodes 1 --micro_bsz 1 --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 --my_exit_tokens 1498226207 --magic_prime 2926181 \ + --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 \ + --accelerator cpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 --enable_progress_bar False --ds_bucket_mb 200 diff --git a/finetune/lora/v6/demo/demo-training-run.sh b/finetune/lora/v6/demo/demo-training-run.sh new file mode 100644 index 0000000..14cd598 --- /dev/null +++ b/finetune/lora/v6/demo/demo-training-run.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +BASE_NAME="model/0.1-1" +N_LAYER="12" +N_EMBD="768" +M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM) +LR_INIT="6e-4" +LR_FINAL="6e-5" +GRAD_CP=0 # set to 1 to save VRAM (will be slower) +EPOCH_SAVE=10 + +# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case) +# use https://www.dcode.fr/prime-numbers-search + +python train.py --load_model "0" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \ + --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \ + --data_file "data/minipile" --my_exit_tokens 1498226207 --magic_prime 2926181 \ + --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \ + --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \ + --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \ + --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200 diff --git a/finetune/lora/v6/demo/demo.jsonl b/finetune/lora/v6/demo/demo.jsonl new file mode 100644 index 0000000..7d50e92 --- /dev/null +++ b/finetune/lora/v6/demo/demo.jsonl @@ -0,0 +1,182 @@ +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Write an article based on this \"A man has been charged with murder and attempted murder after a woman and the man she was on a date with were stabbed at a restaurant in Sydney, Australia.\"\n\nAssistant: Title: Tragedy Strikes in Sydney: Victims Stabbed during Dinner Date, Suspect Charged with Murder and Attempted Murder\nIntroduction\nA seemingly ordinary evening took a tragic turn in Sydney, Australia when a romantic dinner date at a local restaurant ended in bloodshed, as a man allegedly attacked a couple, stabbing both victims. Local authorities have charged the suspect with murder and attempted murder in connection with the horrific crime. In this article, we delve into the chilling incident, the suspect, and the impact it has had on the community.\nThe Incident\nThe alarming event took place at a popular restaurant in Sydney, Australia. A woman and her date were enjoying their meal – blissfully unaware of the tragic circumstances that would soon unfold. As they shared conversation and indulged in their dinner, a man approached the couple and, without warning, began to stab them with a sharp object.\nPatrons at the restaurant scrambled for safety, as the assailant continued his violent rampage. Bystanders quickly called emergency services, who promptly arrived at the scene. Although both victims were rushed to a nearby hospital, the woman succumbed to her injuries. The man she was on a date with remained in critical condition, fighting for his life.\nThe Suspect\nInitial investigations revealed that the alleged attacker had no known connections to the victims – adding to the mystifying nature of this sudden and brutal assault. Witnesses reported that the assailant seemed to have no specific motive and appeared to carry out the act senselessly.\nFollowing a thorough investigation, local police identified and arrested the suspect. During the process, it was discovered that the alleged attacker had a history of criminal behavior and a troubled past, though it is unclear if this played a role in the tragic incident.\nAuthorities have formally charged the man with murder and attempted murder in connection with the heinous crime. He awaits a hearing to determine a trial date and, if convicted, could face a life sentence in prison.\nThe Community's Response\nThe shocking nature of the crime has left the residents of Sydney reeling, as they struggle to come to terms with the harrowing event. The restaurant where the attack occurred has since been closed, with a makeshift memorial being created outside to commemorate the victims.\nMany have questioned how such a vicious crime could happen in what is considered to be one of the safest cities in the world. This tragic event has spurred local officials to reassess current security measures and devise strategies to reinforce public safety. Additionally, discussions surrounding mental health and criminal rehabilitation have surfaced as residents seek to comprehend the actions of the alleged perpetrator and prevent similar incidents from occurring in the future.\nIn the wake of the stabbing, the community has banded together with an outpouring of grief and support for the victims and their families. Candlelight vigils have been held, and an online fundraising campaign is underway to assist the surviving victim with his medical expenses and recovery.\nConclusion\nThe tragic attack in Sydney serves as a chilling reminder that senseless acts of violence can happen anywhere and at any time. The community's response to this horrific and seemingly random act of brutality has been one of solidarity and determination to prevent such incidents in the future. As the case unfolds, the victims and their families remain in the hearts of the community, who are grieving the devastating loss of a life cut tragically short and supporting the recovering victim as he continues to endure this unimaginable ordeal."} +{"text": "User: Q:I read this background article the other day: Water stratification is when water masses with different properties - salinity (halocline), oxygenation (chemocline), density (pycnocline), temperature (thermocline) - form layers that act as barriers to water mixing which could lead to anoxia or euxinia.[1] These layers are normally arranged according to density, with the least dense water masses sitting above the more dense layers. Water stratification also creates barriers to nutrient mixing between layers. This can affect the primary production in an area by limiting photosynthetic processes. When nutrients from the benthos cannot travel up into the photic zone, phytoplankton may be limited by nutrient availability. Lower primary production also leads to lower net productivity in waters.[2] I am facing a new situation today: High lake had a lot of fresh and salt water mixing, therefore presented layers of stratified water.The water was poor in nutrients.Low lake had no such problems and it was a very abundant and prolific fishing spot. Using the knowledge I acquired from the background article, how should I answer correctly the following question regarding my new situation: Which lake had fewer dissolved nutrients?\nA:\n\nAssistant: The lake with fewer dissolved nutrients would be High lake. This is because it had layers of stratified water, which acts as a barrier to nutrient mixing. As a result, it was poor in nutrients."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Summarize this article in one sentence.\nYou'll find this in your Start menu. You can also press ⊞ Win and type \"xbox.\" You can use the Xbox app to take a screenshot of any game running in Windows 10. You'll find this in the menu on the left side of the screen. This will display the DVR and screenshot settings. The Game DVR will need to be enabled in order to take screenshots with the app. Click the slider to toggle it on. When the Game DVR is enabled, the built-in shortcut is ⊞ Win+Alt+PrtScn. You can click the empty field next to this and set your own custom shortcut if you want. This will open the folder that your screenshots will be saved in. If you want to change where your screenshots are saved, you'll need to move the Captures folder to that location. Screenshots will automatically save to the Captures folder, wherever it is. For example, to move it to your Pictures folder, move up one directory and then drag the Captures folder to the Pictures folder in your sidebar. The Xbox app does not need to be open to use the screenshot feature. It will be saved to the Captures folder that you may have moved earlier.\n\nAssistant: The article explains how to use the Xbox app in Windows 10 to take screenshots of games by enabling the Game DVR and using the built-in shortcut or a custom shortcut, with the screenshots saved in the Captures folder which can be moved to a different location if desired."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: what brought him into the field was the idea of louis de broglie that particles have a wave aspect. schrödinger sought our de broglie's work and read it avidly.\nRepeat this setence, but with the correct capitalization.\n\nAssistant: What brought him into the field was the idea of Louis de Broglie that particles have a wave aspect. Schrödinger sought out de Broglie's work and read it avidly.\n1. I capitalized the first letter of the first word of the sentence, \"What.\"\n2. I capitalized the first letter of the proper name \"Louis\" and \"Broglie\" (last name).\n3. I capitalized the first letter of the proper name \"Schrödinger.\"\n4. I corrected the typo \"our\" to \"out.\""} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Add spaces: ifmelectronicgmbh•Friedrichstraße1•45128Essen—Wereservetherighttomaketechnicalalterationswithoutpriornotice.—GB—SV5200—03.06.2015\n\nAssistant: The task provided is to add spaces to a long, uninterrupted string of text to make it more readable. While determining where to place spaces, I'll be guided by the rules of the English language, as well as presumptions about the format of addresses, company names, dates, and technical/legal notes.\n1. Breaking down company details:\nFirst, I'm going to separate individual words and details related to the company and address. \n- \"ifmelectronicgmbh\": This appears to be a company name, which should be \"ifm electronic gmbh\".\n- \"•Friedrichstraße1•45128Essen\": This seems to be a street address in Germany (Friedrichstraße 1, 45128 Essen).\n- The bullet character (•) can be considered as a separator.\nAfter making these changes, the phrase so far is: \"ifm electronic gmbh • Friedrichstraße 1 • 45128 Essen\".\n2. Handling reservation clause:\nNext, I notice the segment of text starting with \"Wereservetherighttomaketechnicalalterationswithoutpriornotice.\" This appears to be a standard non-liability clause. Let's break it down:\n- \"Wereservetherighttomaketechnicalalterationswithoutpriornotice\": This can be separated into \"We reserve the right to make technical alterations without prior notice.\"\nAfter adding these modifications, the sentence so far becomes \"ifm electronic gmbh • Friedrichstraße 1 • 45128 Essen - We reserve the right to make technical alterations without prior notice.\"\n3. Identifying product details and date:\nTo wrap up, I'll divide the final details which seem to be related to a product model number and the date:\n- \"—GB—SV5200—03.06.2015\": Here, I can establish that \"GB\" and \"SV5200\" are likely separate entities (looks like product identifiers or codes), and \"03.06.2015\" is probably a date.\nTherefore, this segment becomes: \"- GB - SV5200 - 03.06.2015.\"\nSo, the final version of the separated text is: \"ifm electronic gmbh • Friedrichstraße 1 • 45128 Essen - We reserve the right to make technical alterations without prior notice. - GB - SV5200 - 03.06.2015.\"\nThis step-wise breakdown splits the original text into more readable sentences and phrases while maintaining their meanings and serving the reader's clarity."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: What is the sentiment of the following movie review sentence?\na christmas perennial\n\nAssistant: A \"Christmas perennial\" means a movie that is enjoyed every Christmas and comes back like a tradition. So, the sentiment of this movie review sentence is positive because it suggests that people like to watch this movie again and again during Christmas time."} +{"text": "User: What type of details about the tenors can be gathered from the following bio? Bio: the tenors -lrb- formerly known as the canadian tenors -rrb- are a vocal quartet consisting of remigio pereira , victor micallef , fraser walters , and clifton murray . they perform operatic pop music that is a mixture of classical and pop , featuring songs such as `` the prayer '' and `` panis angelicus '' from the former genre , and leonard cohen 's `` hallelujah '' from the latter . originating from canada , pereira from ottawa/gatineau , micallef from toronto , walters from vancouver and murray from port mcneil , the tenors have performed on more than 60 international television programs . they appeared on the oprah winfrey show with celine dion in 2010 , at the 63rd primetime emmy awards , at the opening ceremonies of 2010 winter olympics in vancouver and on itv -lrb- tv network -rrb- diamond jubilee show at windsor castle for hm queen elizabeth ii . the venues they have performed in include the tel aviv opera house , and the mandalay bay events center in las vegas . they also appeared on the 2009 grey cup broadcast and on cbc television 's `` holiday festival of ice '' . they have shared the stage with sarah mclachlan , neil young , paul anka , justin bieber , paul mccartney , brian mcknight , david foster , and natalie cole . universal music group released their platinum-selling self-titled debut album and double-platinum holiday album , the perfect gift . the tenors filmed a pbs special at the smith center for the performing arts that began airing in 2012 . the canadian tenors were renamed the tenors in november 2012 in advance of their third album lead with your heart . also in 2012 , the tenors signed a us distribution deal with verve records . the tenors latest release lead with your heart went platinum in canada in seven weeks . the album garnered a 2013 juno award for adult contemporary album of the year and a juno award nomination for jack richardson for producer of the year for bob ezrin 's production of the title track `` forever young . '' lead with your heart was featured on katie , good...\nThe answer to this question is:\n\nAssistant: From the bio, we can gather that The Tenors:\n1. Were formerly known as The Canadian Tenors.\n2. Consist of Remigio Pereira, Victor Micallef, Fraser Walters, and Clifton Murray.\n3. Perform operatic pop music, a mixture of classical and pop.\n4. Have songs such as \"The Prayer,\" \"Panis Angelicus,\" and Leonard Cohen's \"Hallelujah\" in their repertoire.\n5. Originated from Canada, with members from Ottawa/Gatineau, Toronto, Vancouver, and Port McNeil.\n6. Have performed on more than 60 international television programs, including The Oprah Winfrey Show, the 63rd Primetime Emmy Awards, and the 2010 Winter Olympics opening ceremonies.\n7. Performed at venues such as the Tel Aviv Opera House and Mandalay Bay Events Center in Las Vegas.\n8. Have shared the stage with artists like Sarah McLachlan, Neil Young, Paul Anka, Justin Bieber, Paul McCartney, Brian McKnight, David Foster, and Natalie Cole.\n9. Released platinum-selling self-titled debut album, a double-platinum holiday album called The Perfect Gift, and their third album, Lead With Your Heart.\n10. Filmed a PBS special at the Smith Center for the Performing Arts.\n11. Renamed themselves from The Canadian Tenors to The Tenors in November 2012.\n12. Signed a US distribution deal with Verve Records in 2012.\n13. Received a 2013 Juno Award for Adult Contemporary Album of the Year and a Juno Award nomination for Jack Richardson Producer of the Year for their album Lead With Your Heart."} +{"text": "System: You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. Think like you are answering to a five year old.\n\nUser: What is the most logical next event?\nHow to unblock copyright infringement on youtube \nUnderstand why you may have received a content id claim. \nContent id is a system that identifies potential copyrighted content in videos by scanning uploaded videos for previously-uploaded content. The system will scan for audio, video, and images.\nChoose your answer from:\na). Answer as many questions as you can. If you are currently having trouble out of confusion or have not answered a question assigned by your user, then this is a great way to check for copyright infringement.\nb). If a match occurs, the original owner is notified and a content id claim is filed. The original owner can choose to do nothing, mute their audio in your video, block the video from being viewed, monetize the video, or track the video's viewership.\nc). Figure out whether you have used the content id app. Copyright-infringement lawsuits can take up to a few years to fully prove that you use the content id app in your videos.\nd). If you have uploaded video videos to youtube, it will scan for audio, video, and images in your channel. You may receive your content id when you sign up for a video profile.\n\nAssistant: b). If a match occurs, the original owner is notified and a content id claim is filed. The original owner can choose to do nothing, mute their audio in your video, block the video from being viewed, monetize the video, or track the video's viewership.\nThis is the correct answer because it describes the next logical event in the process of handling a content id claim on YouTube. When a match is found, the original owner gets to decide what to do with the claimed content.\na) is wrong because it talks about answering questions and checking for copyright infringement, which is not directly related to the next step in unblocking copyright infringement on YouTube.\nc) is incorrect because it discusses proving the use of the content id app in lawsuits, which is not the next step in dealing with a content id claim on YouTube.\nd) is wrong because it talks about uploading videos to YouTube and receiving content id, but not about the next step in resolving a content id claim."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: OPTIONS:\n[i] Yes;\n[ii] No;\nGiven those answer options, answer the question: Title: Good book with a bad description of the product Review: Although this looks to be a great book, I was disappointed to see the giant inscription on the first page of the book. Because this book was for a gift, noting this inscription on the product description would have been very helpful Is this product review negative?\nA:\n\nAssistant: Yes, this product review is negative. The reviewer expresses disappointment with the product due to the giant inscription on the first page, which was not mentioned in the product description. They also mention that the book was intended as a gift, which further explains their dissatisfaction with the unexpected inscription. Although the reviewer acknowledges that the book might be great, their overall experience with the product seems to be negative due to this issue."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Write some highlights for the following article:\nThe historic six-month agreement over Tehran's nuclear program may begin a new era of relations with Iran, but it will be a long road back for the country's most vital sector, oil. Iran produces about two and half million barrels a day -- far off its 4-million-barrel-per-day peak a decade ago. Output is hovering at a level last seen at the end of Iran's war with Iraq. With North Sea Brent crude averaging over $100 a barrel for a record three years running, the sanctions on energy alone are costing Tehran about $50 billion in lost annual revenue. Despite the deal breakthrough, U.S. Secretary of State John Kerry said most of the sanctions will stick as the world gauges the intentions of this relatively new administration in Tehran. \"During the six month phase, the oil sanctions that will remain in place will continue to cause over $25 billion in lost revenues to Iran or over $4 billion a month,\" he said. The message is clear: the pressure remains, but if all goes well, in a half year's time Iran can expect more in return for transparency. The last few years have been filled with uncertainty. Tensions around the Strait of Hormuz with the on-and-off threats by the previous government of Mahmoud Ahmadinejad to block oil shipments in the Gulf have kept what strategists call a 10-15% risk premium on global energy prices. Iran sits on about 9% of the world's proven oil reserves, claiming a few years back that it has nearly 150 billion barrels and the world's largest gas field. But its top four customers --China, India, Japan and South Korea -- have all had to cut back their energy imports by a third or more in the past few years due to U.S. and European pressure. With every year that has passed, the screws have been tightened by Washington and the countries of the European Union. It was not only sanctions against oil, but also blocking Iran's ability to secure shipping insurance and to trade in U.S. dollars and euros. That economic isolation, many Middle East strategists I have spoken with suggest, is what brought Iran's new government to the negotiating table. In 2012, the Iranian rial plunged by up to 80%. Basic staples of life, poultry and bread skyrocketed in price and the economy went through a period of hyper-inflation. With rising import prices due to a plummeting currency, industrial Iran is no longer able to compete. There are reports that the state manufacturing sector had laid off up to 800,000 workers in 2012 and those who have kept their jobs saw their wages eroded by skyrocketing prices. Oil executives with experience in the country say this initial agreement would help lift a cloud of uncertainty over the oil market, but that President Hassan Rouhani and his cabinet have to establish much better contract terms if this honeymoon period lifts sanctions. The blunt-speaking Chief Executive of French energy group Total, Christophe de Margerie, told me at an energy conference in Abu Dhabi this month that Iran needs to try and create a better climate for investment if this weekend's breakthrough is sustained. \"I can tell you with the experience we have from Iran it doesn't always bring as we say in French partie de plaisir ... a win-win.\" Oil giant Saudi Arabia has expressed doubts about signing a deal with Tehran, which will introduce more challenges within OPEC. Iraq plans to double production by 2020 to six million barrels a day and with Iran wanting to rebuild exports, the Kingdom may have trim its own production to defend prices. It is still early days, but this country of nearly 80 million people has been described as potentially being the Germany of the Middle East with plenty of natural resources -- that is, if it can emerge from years of economic isolation.\nHighlights:\n\nAssistant: 1. Historic six-month agreement over Tehran's nuclear program presents an opportunity for Iran's oil sector recovery, but the journey will be long and difficult.\n2. Iran's oil production is currently around 2.5 million barrels per day, significantly less than its peak of 4 million barrels per day a decade ago.\n3. Sanctions on energy alone are costing Tehran approximately $50 billion in lost annual revenue.\n4. U.S. Secretary of State John Kerry said most sanctions will remain in place during the six-month phase, causing Iran to lose over $25 billion in revenues.\n5. Iran's top four oil customers, China, India, Japan, and South Korea, have had to cut back their imports by a third or more due to U.S. and European pressure.\n6. Iran's economy suffered greatly during the period of sanctions, with rampant inflation and industrial layoffs.\n7. Establishing better contract terms is crucial for Iran to lift sanctions and improve the investment climate.\n8. Oil giant Saudi Arabia has doubts about signing a deal with Tehran, introducing challenges within OPEC for Iran's oil output growth."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Given the question: Read the bio below and try to give details on aidan somers 's: - sport - county - icposition - province - icupdate - birth date - allstars - clprovince - code - name - birth place - icprovince - nhl - counties - icyears - irish - icallireland - clcounty - club Bio: aidan somers -lrb- born 1946 -rrb- is an irish retired hurler who played as a right corner-back for the wexford senior team . born in rathnure , county wexford , somers first excelled at hurling in his youth . he arrived on the inter-county scene at the age of seventeen when he first linked up with the wexford minor team , before later joining the under-21 side . he joined the senior panel during the 1970 championship and played for just one season . somers won one leinster medal on the field of play . at club level somers is a three-time leinster medallist with rathnure . in addition to this he also won eight championship medals .\nThe answer is:\n\nAssistant: Aidan Somers is a retired Irish hurler who has had an accomplished career in the sport. Below are the details based on the provided bio:\n- Sport: Hurling\n- County: Wexford\n- ICPosition: Right corner-back\n- Province: Leinster (Ireland)\n- ICUpdate: Somers joined the senior panel during the 1970 championship and played for just one season.\n- Birth Date: 1946\n- AllStars: Not mentioned in the bio\n- CLProvince: Leinster (won three Leinster medals with Rathnure club)\n- Code: Not mentioned in the bio\n- Name: Aidan Somers\n- Birth Place: Rathnure, County Wexford, Ireland\n- ICProvince: Leinster (won one Leinster medal on the field of play)\n- NHL: Not mentioned in the bio\n- Counties: Wexford\n- ICYears: Started at the age of seventeen with Wexford minor team and later joined the under-21 side before joining the senior panel in the 1970 championship (played for just one season)\n- Irish: Yes, Aidan Somers is of Irish nationality\n- ICAllIreland: Not mentioned in the bio\n- CLCounty: Wexford (won eight championship medals with Rathnure club)\n- Club: Rathnure\nAidan Somers was born in 1946 in Rathnure, County Wexford. He initially demonstrated his hurling talent while he was still young. He joined the inter-county scene at the age of seventeen, starting with the Wexford minor team before later progressing to the under-21 squad. Somers eventually became a member of the senior panel during the 1970 championship, where he played for only one season as a right corner-back.\nThroughout his career, Somers achieved notable successes both at the county and club levels. With the Wexford senior team, he won one Leinster medal on the field. As a member of the Rathnure club, he secured three Leinster medals and eight championship medals, demonstrating his exceptional skills and dedication to the sport of hurling."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: No 2328/91 (1), as last amended by Regulation (EC) No 3669/93 (2)\"> ID=\"1\">- Dried farmyard manure and dehydrated poultry manure> ID=\"2\">Need recognized by the inspection body or inspection authority\"> ID=\"2\">Indication of animal species\"> ID=\"2\">Coming from extensive husbandry and only in the sense of Article 6 (4) of Regulation (EEC) No 2328/91\"> ID=\"1\">- Composted animal excrements, including poultry manure and composted farmyard manure included> ID=\"2\">Need recognized by the inspection body or inspection authority\"> ID=\"2\">Indication of the animal species\">\nTranslate to Czech\nCzech:\n\nAssistant: Článek č. 2328/91 (1), naposledy změněn nařízením (ES) č. 3669/93 (2)\"> ID=\"1\">- Sušený hnoj ze dvora a dehydratovaný drůbeží trus> ID=\"2\">Potřeba uznána inspekčním orgánem nebo inspekčním úřadem\"> ID=\"2\">Označení druhu zvířat\"> ID=\"2\">Pocházející z extenzivního chovu a pouze ve smyslu článku 6 (4) nařízení (EHS) č. 2328/91\"> ID=\"1\">- Kompostované živočišné exkrementy, včetně drůbežího trusu a kompostovaného dvorského hnoje> ID=\"2\">Potřeba uznána inspekčním orgánem nebo inspekčním úřadem\"> ID=\"2\">Označení druhu zvířat\">\nČlánek č. 2328/91 (1), jak byl naposledy změněn nařízením (ES) č. 3669/93 (2) uvádí následující:\n- Sušený hnoj ze dvora a dehydratovaný drůbeží trus:\nPoužití těchto materiálů musí být uznáno inspekčním orgánem nebo inspekčním úřadem. Je nutné uvést druh zvířat, ze kterého hnoj pochází. Materiál může být použit pouze v případě, že pochází z extenzivního chovu a použití je v souladu s článkem 6 (4) nařízení (EHS) č. 2328/91.\n- Kompostované živočišné exkrementy, včetně drůbežího trusu a kompostovaného dvorského hnoje:\nTaké pro použití těchto materiálů je nutné získat uznání inspekčním orgánem nebo inspekčním úřadem. Je třeba uvést druh zvířat, z kterého exkrementy pochází."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: (ES) Domnule președinte, mulțumesc mult, domnule comisar, pentru răspunsul dvs. care va fi foarte util pentru summitul Uniunea Europeană-Africa.\nTranslate this to English?\n\nAssistant: Step 1: Identify the language of the text. In this case, the language is Romanian.\nStep 2: Understand the words and phrases in this Romanian text. \nDomnule președinte - Mr. President\nmulțumesc mult - thank you very much\ndomnule comisar - Mr. Commissioner\npentru - for\nrăspunsul - the response\ndvs - your\ncare - which\nva fi - will be\nfoarte util - very useful\npentru - for\nsummitul - the summit\nUniunea Europeană-Africa - European Union-Africa\nStep 3: Translate the text into English based on the understanding of the words and phrases.\n\"Mr. President, thank you very much, Mr. Commissioner, for your response which will be very useful for the European Union-Africa summit.\"\nIn this translation, I have arranged the words and phrases in an order that makes sense in English while maintaining the intended meaning of the text."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: Instructions: Given a sentence, generate a most unlikely next statement. The next statement should not be logically relevant to the given context\nInput: Someone's platoon is attacked. A bullet\nOutput:\n\nAssistant: The kitchen sink needs to be replaced.\nIn order to come up with an unlikely next statement, I focused on creating a statement that is not logically connected to the original sentence. The original sentence involves a dangerous attack situation, so I introduced a mundane, unrelated topic (kitchen sink) to make the statement irrelevant to the initial context."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: What is the title of this article:\nThe second decline was not as complete as the first. Thousands of people had learned to play the instrument. Even as the second wave of mandolin popularity declined in the early 20th century, new versions of the mandolin began to be used in new forms of music. Luthiers created the resonator mandolin, the flatback mandolin, the carved-top or arched-top mandolin, the mandolin-banjo and the electric mandolin. Musicians began playing it in Celtic, Bluegrass, Jazz and Rock-n-Roll styles — and Classical too.\n\nAssistant: The title of this article could be: \"The Evolution and Continued Popularity of the Mandolin in Music.\""} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: What is the answer to this question? who wrote the night they drove ole dixie down?\n\nAssistant: The song \"The Night They Drove Ole Dixie Down\" was written by a man named Robbie Robertson. He was part of a music group called \"The Band\". It's like when you and your friends make up a fun song together, Robbie did the same with his friends in \"The Band\"."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Please answer the following question: Information: - Japan (\"Nippon\" or \"Nihon\" ; formally \"\" or \"Nihon-koku\", means \"State of Japan\") is a sovereign island nation in Eastern Asia. Located in the Pacific Ocean, It is lying off the eastern coast of the Asia Mainland (east of China, Korea, Russia) and stretching from the Sea of Okhotsk in the north to the East China Sea and near Taiwan in the southwest. - Plot. The story focuses on Kazuma Azuma, a boy on his quest to create \"\"Ja-pan\"\", a national bread for Japan itself. He initially heads to Tokyo to expand his horizons at the bread-making chain Pantasia; the story continues with Azuma's exploits with his other coworkers. - Utatane is the first album from the Japanese duo RYTHEM, released in 2004. - In January 2009, the series went on hiatus due to Tobe's illness and remained uncompleted when she died in January 2010. Her works during her illness were collected and published with some of her earlier works as volume 15 in June 2010. - \"Naruto\" was serialized in \"Weekly Shnen Jump\" magazine from the 43rd issue of 1999 to the 50th issue of 2014, with the chapters collected into seventy-two \"tankbon\" volumes by Shueisha. The manga was later adapted into a television anime, which was produced by Studio Pierrot and Aniplex. It premiered across Japan on the terrestrial TV Tokyo and other TX Network stations on October 3, 2002. The first series lasted 220 episodes, while \"\", a sequel to the original series, has been airing since February 15, 2007. In addition to the anime series, Studio Pierrot has developed eleven movies and several original video animations. Other types of merchandise include light novels, video games, and trading cards developed by several companies. - Mugen Factory ( , Infinite Factory ) is the second album from the Japanese duo Rythem . It was released on May 24 , 2006 under Sony Music Entertainment Japan . The album contains songs from their 6th up to 10th single . The item 's stock number is AICL - 1746 . - History. Rythem's talents as singer/songwriters, with flowing melodies and harmonies, began to shine as they worked with their own original tune \"Hitoritabi Shararuran\" in the Sony Music Audition. In May 2003, while they were still high school seniors, Rythem's first single \"Harmonia\" was released and was used as the second ending theme song for hit TV TOKYO anime series, \"Naruto\". Following the successful release of their debut single, 4 more singles including the theme song for \"With the Light\", and the debut album, \"Utatane\" were released before \"Hki Gumo\" was released in early 2005 as a single for the first opening theme of yet another hit TV TOKYO anime series, \"Yakitate!! Japan\". What is the relationship between 'mugen factory' and 'utatane'?\nA:\n\nAssistant: The relationship between 'Mugen Factory' and 'Utatane' is that they are both albums from the Japanese duo RYTHEM. 'Utatane' is their first album, released in 2004, while 'Mugen Factory' is their second album, released on May 24, 2006, under Sony Music Entertainment Japan. Both albums showcase RYTHEM's musical talents, featuring their unique melodies and harmonies in various songs."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Here's a complex question that requires someone to reason about the input, can you answer it? Gerard Brantsen was one of the Dutch signatories to the treaty ending a war that began in what year?\nThe answer to this question is:\n\nAssistant: To answer this complex question, we must first examine the historical figure Gerard Brantsen and identify which war was ended with the treaty he signed. \nGerard Brantsen (1735-1809) was a Dutch diplomat and politician in the 18th-19th century. His most significant contribution was to the Treaty of Paris in 1783, which marked the end of the Fourth Anglo-Dutch War.\nThe Fourth Anglo-Dutch War was a conflict between the Kingdom of Great Britain and the Dutch Republic, and commenced on December 20, 1780. The war was primarily driven by commercial rivalries and political tensions that emerged from the American Revolutionary War. At this time, the Dutch Republic, in an attempt to protect their trade interests, sought to provide military assistance and resources to the newly established United States of America. This support of the rebellious colonies angered Great Britain and consequently led to the escalation of hostilities between the two European powers.\nThe Treaty of Paris, which ended the Fourth Anglo-Dutch War, was signed on September 3, 1783. The negotiations brought together the major powers involved in the American Revolutionary War. Delegates from the United States, France, Spain, and the Dutch Republic were present to deliberate the terms for reestablishing diplomatic relations and resolving territorial disputes. Gerard Brantsen, acting as one of the Dutch envoys, played a significant role in the negotiation process.\nIn conclusion, Gerard Brantsen was one of the Dutch signatories to the Treaty of Paris, which terminated the Fourth Anglo-Dutch War. This war began on December 20, 1780."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: This is the content of an email: Julie Markantonis has her ladder over at my house (908 Bomar). Not sure how tall it is, but it is very tall. She says you can use it if it will work. \nWhat was the subject line for this email?\n\nAssistant: \"Julie's Tall Ladder Available for You at 908 Bomar\""} +{"text": "User: Process: - An owl leaves his nest - The owl flys out into the night - The owl looks down at the ground - The owls keen eyesight allows him to see in the dark - The owls sensitive hearing allows him to locate hidden prey - The owl swoops down - The owl catches a mouse in his talons. Question: suppose owls have better eyesight happens, how will it affect owls see more prey. How does the supposed perturbation influence the second effect mentioned. Answer by more, less or no effect\nAnswer:\n\nAssistant: more effect"} +{"text": "User: Process: - Solar radiation reaches Earth's atmosphere - Some is reflected back into space - The rest of the energy is absorbed by land and oceans, heating the Earth - Heat radiates from Earth to space - Some of this heat is trapped by greenhouse gases in the atmosphere - Human activities such as burning fuel are increasing the amount of greenhouses gases - This causes more heat radiation to stay trapped in our atmosphere - The Earth's temperature rises - Question: suppose Start to clean the air and store pollution somewhere else. happens, how will it affect a GREATER greenhouse effect. - A: more - B: less - C: no effect\n\nAssistant: B: less"} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Please answer the following question: Title: Buy the real one...written when she was alive, before 1947. Review: Read the real stories by the real Lucia here on Amazon, not the ones propagated by the Vatican straw man. Read links here: [...] Does this product review convey a negative or positive sentiment?\nA:\n\nAssistant: This product review conveys a negative sentiment. The reviewer appears to be discouraging potential buyers from purchasing the product in question, suggesting that they should instead opt for the \"real stories by the real Lucia\" available on Amazon. The mention of the \"Vatican straw man\" implies that the reviewer believes the current product may contain distorted or inaccurate information. Furthermore, the inclusion of external links reinforces the reviewer's intent to redirect potential buyers to alternative sources."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Title: My favourite book on bidding Review: This is a fantastic book - although quite slim, there is lots of useful content. Using the principles it contains, I was one of the few that found the 24 point slam in the British Bridge League simultaneous pairs in January 1998. Is the review positive or negative?\nAnswer:\n\nAssistant: The review is positive. The reviewer describes the book as \"fantastic\" and appreciates its useful content. They also mention how the book's principles helped them succeed in a British Bridge League event, which indicates their satisfaction and positive experience with the book."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Please answer the following question: Write a multi-choice question for the following article, with the given choices and answer: Article: Barefoot and dirty, the girl just sat and watched the people go by. She never tried to speak. Many people passed, but never did one person stop. The next day I decided to go back to the park, curious to see if the little girl would still be there. Right in the very spot as she was yesterday, she sat with the saddest look in her eyes. Today I was to make my own move and walk over to the little girl. As we all know, a park full of strange people is not a place for young children to play alone. As I began walking towards her, I could see the back of the little girl's dress indicated a deformity . I smiled to let her know it was okay, I was there to help, to talk. I sat down beside her and opened with a simple \"Hello\". I smiled and she shyly smiled back. We talked until darkness fell and the park was completely empty. I asked the girl why she was so sad. The little girl looked at me and with a sad face said, \"Because I'm different.\" I immediately said \"That you are!\" and smiled. The little girl acted even sadder. She said, \"I know.\" She looked at me and smiled. Slowly she stood to her feet, and said, \"Really?\" \"Yes, dear, you're like a little guardian angel sent to watch over all those people walking by.\" She nodded her head \"yes\" and smiled, and with that she spread her wings and said with a twinkle in her eye, \"I am. I'm your guardian angel.\" I was speechless, sure I was seeing things. She said, \"For once you thought of someone other than yourself, my job here is done.\" Immediately I stood to my feet and said, \"Wait, so why did no one stop to help an angel!?\" She looked at me and smiled, \"You're the only one who could see me, and you believe it in your heart.\" ...and she was gone. And with that my life changed dramatically. So, when you think you're all you have, remember, your angel is always watching over you. Mine is! Options: A they pretended not to see her B she pretended to be poor C they never thought of helping her D they knew the author would help her Answer: C they never thought of helping her Question:\nA:\n\nAssistant: In the article, why did people pass by the little girl without stopping to help her? \nA: They pretended not to see her\nB: She pretended to be poor\nC: They never thought of helping her\nD: They knew the author would help her"} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: \"Broken Arrow\" is a song written by Canadian singer-songwriter Neil Young and recorded by Buffalo Springfield on their 1967 album \"Buffalo Springfield Again\". It was recorded in August and September 1967 at Columbia Recording Studios and Sunset Sound Recorders. It incorporates musical ideas from \"Down Down Down,\" a demo Young recorded with Buffalo Springfield (now available on the box set).\nCan we draw the following conclusion?\n\"Broken Arrow\" is on the Buffalo Springfield box set.\n\nAssistant: Based on the information provided, we cannot directly conclude that \"Broken Arrow\" is on the Buffalo Springfield box set. The details mention that \"Broken Arrow\" is a song written by Neil Young and recorded by Buffalo Springfield on their 1967 album \"Buffalo Springfield Again.\" Additionally, it is mentioned that \"Broken Arrow\" incorporates musical ideas from \"Down Down Down,\" a demo Young recorded with Buffalo Springfield, which is now available on the box set.\nAlthough we know that the demo \"Down Down Down\" is included in the box set, we do not have sufficient information to confirm that \"Broken Arrow\" is also part of the same box set."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Given the below context: On January 5, 1900, four friends arrive for a dinner at the London home of their inventor friend George, but he is not there. He arrives suddenly, bedraggled and exhausted, and begins describing the strange experiences he has had since they last met. At their earlier dinner on New Year's Eve, George said that time was \"the fourth dimension\". He shows David Filby, Dr. Philip Hillyer, Anthony Bridewell, and Walter Kemp a small model time machine and asks one of them to press a tiny lever on the model. The device disappears, validating his claim, but his friends remain unconvinced; their reactions vary from curiosity to frank dismissal. George bids his guests a good evening, then heads downstairs where his full-size time machine awaits. He presses a lever and moves forward through time 17 years into the future to September 15, 1917. He meets Filby's son, James, who tells him of Filby's death in a war. Saddened, he resumes his journey, stopping on June 19, 1940 during The Blitz, finding himself in the midst of \"a new war\"; George resumes his journey and his house is hit by a bomb and is destroyed. George stops on August 19, 1966 finding his neighborhood now part of a futuristic metropolis. People hurry into a nearby fallout shelter amid the blare of air raid sirens. An elderly James Filby urges George to immediately take cover. Moments later, a nuclear explosion destroys London, causing a volcanic eruption around him. George narrowly makes it back to his machine and continues his journey forward as the lava rapidly rises, cools, and hardens, trapping him inside. He travels much farther into the future until the lava erodes away, revealing a lush, green, unspoiled landscape. Guess a valid title for it!\nAnswer:\n\nAssistant: \"The Time Traveler's Odyssey\""} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Read this: In March 1861, Victoria's mother died, with Victoria at her side. Through reading her mother's papers, Victoria discovered that her mother had loved her deeply; she was heart-broken, and blamed Conroy and Lehzen for \"wickedly\" estranging her from her mother. To relieve his wife during her intense and deep grief, Albert took on most of her duties, despite being ill himself with chronic stomach trouble. In August, Victoria and Albert visited their son, the Prince of Wales, who was attending army manoeuvres near Dublin, and spent a few days holidaying in Killarney. In November, Albert was made aware of gossip that his son had slept with an actress in Ireland. Appalled, Albert travelled to Cambridge, where his son was studying, to confront him. By the beginning of December, Albert was very unwell. He was diagnosed with typhoid fever by William Jenner, and died on 14 December 1861. Victoria was devastated. She blamed her husband's death on worry over the Prince of Wales's philandering. He had been \"killed by that dreadful business\", she said. She entered a state of mourning and wore black for the remainder of her life. She avoided public appearances, and rarely set foot in London in the following years. Her seclusion earned her the nickname \"widow of Windsor\".\nDuring which year did Victoria's mother die?\nWhat is the answer? (If it cannot be answered, return \"unanswerable\")\n\nAssistant: Victoria's mother died in the year 1861."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Question: Here's a complex question that requires someone to reason about the input, can you answer it? What american sports wrote a regular column Spotlight on sports and was featured in the non profit faith based magazine Guideposts?\nAnswer:\n\nAssistant: Answer: American sportswriter and commentator, Bill Marvel, wrote a regular column titled \"Spotlight on Sports\" and was featured in the non-profit, faith-based magazine Guideposts."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: High-definition television\nThe first regular broadcasts started on January 1, 2004 when the Belgian company Euro1080 launched the HD1 channel with the traditional Vienna New Year's Concert. Test transmissions had been active since the IBC exhibition in September 2003, but the New Year's Day broadcast marked the official launch of the HD1 channel, and the official start of direct-to-home HDTV in Europe.\nQ: What event was first broadcasted on January 1, 2004?\n\nAssistant: The first event that was broadcasted on January 1, 2004, was the traditional Vienna New Year's Concert. This was the official start of a new TV channel called HD1, which showed very clear and nice looking pictures called high-definition television in Europe."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Extract the answer to the following question from the movie plot. If the question isn't answerable, please output \"Can't answer\". Question: What does Harold bet and win on? Title: The Sin of Harold Diddlebock Movie plot: In 1923, college freshman Harold Diddlebock (Harold Lloyd) is brought into his college's football team where he scores the winning touchdown for his team (as told in 'The Freshman'). The mild-mannered Harold Diddlebock is promptly hired by the pompous advertising tycoon J.E. Waggleberry (Raymond Walburn) to work at his firm. Harold is given a low-paying job at the firm's bookeeper.Jumping forward 22 years to 1945, Harold (still working in his boring job as bookkeeper) is promptly let go by Waggleberry with nothing but a tiny pension of around $1,000. Harold bids farewell to the girl at the desk down the aisle, Miss Otis (Frances Ramsden), whom he had hoped to marry just as he had hoped to marry five of her older sisters before whom also worked at the firm.As Harold wanders aimlessly through the streets, his life's savings in hand, he has a run-in with an older man; a slimy racetrack tout named Wormy (Jimmy Conlin) and finds himself in a bar. When he tells the bartender (Edgar Kennedy) that he's never had a drink in his life, the barkeep creates a potent cocktail he calls \"The Diddlebock\", one sip of which is enough to release Harold from all his inhibitions, setting him off on a day-and-a-half binge of spending and carousing. Harold then goes to a tailor and barber shop and spends most of his money and bets the rest on a horserace and wins.Some days later, Harold wakes up on the couch in the house of his widowed older sister Flora (Margaret Hamilton). He finds that he has a hangover, but he also has a garish new wardrobe, a ten-gallon hat, a hansom horse-drawn cab complete with a driver, and ownership of a bankrupt circus.Trying to sell the circus, Harold and Wormy visit circus-loving Wall Street banker Lynn Sargent (Rudy Vallee), and then, when he turns them down, so do the rest of the town's bankers. To get past the bank guards, Harold brings along Jackie the Lion, who incites panic. Wormy brings along a thermas containing the potent 'Diddlebock' drink where they give shots to each of the bankers to get...\n\nAssistant: Harold bets the rest of his money on a horserace and wins."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: In Christianity, an archbishop (, via Latin \"archiepiscopus\", from Greek , from -, \"chief\", and , \"bishop\") is a bishop of higher rank or office. In some cases, like the Lutheran Church of Sweden, it is the denomination leader title. Like popes, patriarchs, metropolitans, cardinal bishops, diocesan bishops, and suffragan bishops, archbishops are in the highest of the three traditional orders of bishops, priests, also called presbyters, and deacons. An archbishop may be granted the title, or ordained as chief pastor of a metropolitan see or another episcopal see to which the title of archbishop is attached. \nEpiscopal sees are generally arranged in groups in which the bishop who is the ordinary of one of them has certain powers and duties of oversight over the other sees. He is known as the metropolitan archbishop of that see. In the Roman Catholic Church, canon 436 of the Code of Canon Law indicates what these powers and duties are for a Latin Rite metropolitan archbishop, while those of the head of an autonomous (\"sui iuris\") Eastern Catholic Church are indicated in canon 157 of the Code of Canons of the Eastern Churches. \nAs well as the much more numerous metropolitan sees, there are 77 Roman Catholic sees that have archiepiscopal rank. In some cases, such a see is the only one in a country, such as Luxembourg or Monaco, too small to be divided into several dioceses so as to form an ecclesiastical province. In others, the title of archdiocese is for historical reasons attributed to a see that was once of greater importance.\n1. What is the highest of the bishops?\n2. What religion has a see?\n3. How many of those have an archiepiscopical rank?\n4. Where is one country where there is only one see?\n5. What is another?\n6. Why are there only one there?\n7. Who is normally in charge of a see?\nProvide a numbered list of answers.\n\nAssistant: 1. The highest of the bishops are archbishops, who hold a higher rank or office compared to other bishops in the hierarchy.\n2. Christianity, and more specifically, the Catholic Church has a \"see\" or a diocese, which is the jurisdiction of a bishop.\n3. There are 77 Roman Catholic sees that have archiepiscopal rank.\n4. Luxembourg is one country where there is only one see.\n5. Another country with only one see is Monaco.\n6. There is only one see in these countries because they are too small to be divided into several dioceses so as to form an ecclesiastical province.\n7. A bishop is normally in charge of a see as the chief pastor with oversight powers and duties over other sees within their jurisdiction. In cases where an archbishop is appointed, the archbishop is in charge of the see."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Metal Gear Solid V: The Phantom Pain is an action-adventure stealth game developed by Kojima Productions and published by Konami. It was released worldwide for Microsoft Windows, PlayStation 3, PlayStation 4, Xbox 360 and Xbox One on September 1, 2015. It is the ninth installment in the series that was directed, written and designed by Hideo Kojima following Metal Gear Solid V: Ground Zeroes, a stand-alone prologue released the previous year, as well as his final work at Konami.\nIs it true that is metal gear solid 5 phantom pain open world?\n\nAssistant: Yes, it is true that Metal Gear Solid V: The Phantom Pain is an open-world game. The game features a large, open-world environment that players can explore freely, giving them more freedom in choosing how to approach missions and navigate the game world. This open-world design allows for a more immersive and dynamic gameplay experience compared to previous entries in the Metal Gear Solid series, which were generally more linear in structure."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Edouard Louis Joseph, baron Merckx (Dutch pronunciation: ['merks]; born 17 June 1945), better known as Eddy Merckx, is a Belgian former professional road and track bicycle racer who is widely seen as the most successful rider in the history of competitive cycling. His victories include an unequalled eleven Grand Tours (five Tours of France, five Tours of Italy, and a Tour of Spain), all of the five Monuments, three World Championships, the hour record, every major one-day race other than Paris-Tours, and extensive victories on the track. Born in Meensel-Kiezegem, Brabant, Belgium, he grew up in Woluwe-Saint-Pierre where his parents ran a grocery store. He played several sports, but found his true passion in cycling.\nMerckx opened the 1967 campaign with two stage victories at the Giro di Sardegna. He followed these successes by entering Paris-Nice where he won the second stage and took the race lead. Two stages later, a teammate, Tom Simpson, attacked with several other riders on a climb and was nearly 20 minutes ahead of Merckx, who remained in a group behind. Merckx attacked two days later on a climb 70 km into the stage. He was able to establish a firm advantage, but obeyed orders from his manager to wait for the chasing Simpson. Merckx won the stage, while Simpson secured his overall victory. On 18 March, Merckx started the Milan-San Remo and was seen as a 120-1 favorite to win the race. He attacked on the Capo Berta and again on the Poggio, leaving only Gianni Motta with him. The two slowed their pace and were joined by two more riders. Merckx won the four-man sprint to the finish. His next victory came in La Fleche Wallonne after he missed out on an early break, caught up to it, and attacked from it to win the race. On 20 May, he started the Giro d'Italia, his first Grand Tour. He won the twelfth and fourteenth stages en route to finishing ninth in the general classification. He signed with Faema on 2 September for ten years worth 400,000 Belgian francs. He chose to switch over in order to be in complete control over the team he was racing for. In addition, he would not have to pay for various expenses that came with racing such as wheels and tires. The next day, Merckx started the men's road race at the 1967 UCI Road World Championships in Heerlen, Netherlands. The course consisted of ten laps of a circuit. Motta attacked on the first lap and was joined by Merckx and five other riders. The group thinned to five as they reached the finish line where Merckx was able to out-sprint Jan Janssen for first place. In doing so, he became the third rider to win the world road race amateur and professional titles. By winning the race he earned the right to wear the rainbow jersey as world champion. Answer this question using a quote from the following article:\nDid he have any other victories?\n\nAssistant: Eddy Merckx had an impressive number of victories throughout his career. As mentioned in the original text, his wins include eleven Grand Tours, five Monuments, three World Championships, the hour record, and almost every major one-day race with the exception of Paris-Tours. Furthermore, he achieved numerous victories on the track. His incredible success has earned him the reputation of being the most successful cyclist in the history of competitive cycling, which clearly demonstrates the extent of his additional victories beyond the ones specifically discussed in the text."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Detailed Instructions: In this task, answer 'Yes' if the frequency of the given word in the two sentences is equal, else answer 'No'.\nQ: Sentence1: 'group of pastries sitting on top of a white plate on table', Sentence2: 'a picture of a boy with two clock arms on his face'. Is the frequency of the word 'a' in two sentences equal?\nA:\n\nAssistant: No"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Here the omnibus came up, and I rode back to Manchester. The whole conversation took up very little more time than it will take to read it; but I thought it worth recording, as characteristic of the people now suffering in Lancashire from no fault of their own. I know the people well. The greatest number of them would starve themselves to that degree that they would not be of much more physical use in this world, before they would condescend to beg. But starving to death is hard work. What will winter bring to them when severe weather begins to tell upon constitutions lowered in tone by a starvation diet--a diet so different to what they have been used to when in work? What will the 1s. 6d. a-head weekly do for them in that hard time? If something more than this is not done for them, when more food, clothing, and fire are necessary to everybody, calamities may arise which will cost England a hundred times more than a sufficient relief--a relief worthy of those who are suffering, and of the nation they belong to--would have cost. In the meantime the cold wings of winter already begin to overshadow the land; and every day lost involves the lives, or the future usefulness, of thousands of our best population. \nQuestion: \"In what town would the majority starve before condescending to beg?\"\nAnswer: \"Cardiff\"\nIs this answer to the question correct?\nChoose your answer from: (a). no (b). yes\n\nAssistant: a). no"} +{"text": "System: You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. Think like you are answering to a five year old.\n\nUser: Choose your answer from:\n --negative;\n --positive;\nShort movie review: skeeved \nDid the critic thinking positively or negatively of the movie?\n\nAssistant: --negative\nThe critic was thinking negatively of the movie. When someone says they are \"skeeved\" by a movie, it means they felt uncomfortable or disgusted by it. So, they didn't like it. Positive would mean they liked it or thought it was good, but \"skeeved\" shows they did not enjoy it at all."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Does the following review have a positive or negative opinion of the movie?\nhave completely forgotten the movie by the time you get back to your car in the parking lot\n\nAssistant: To determine if the review has a positive or negative opinion of the movie, I will analyze the content and assess the overall sentiment.\nStep 1: Identify key phrases\nThe key phrase from this review is \"have completely forgotten the movie by the time you get back to your car in the parking lot.\"\nStep 2: Assess the sentiment of the key phrase\nThe phrase implies that the movie is forgettable and not memorable.\nStep 3: Determine the overall opinion\nSince the movie was described as forgettable, the overall opinion of the review is negative."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: 1. Carney is on to something in \"Begin Again,\" some handle on a particular truth he's exploring, and he remains faithful to it through the finish.\n2. \"Begin Again\" is pleasantly predictable if you're in an undemanding mood. If you're not, it's unbearable, like hearing a treasured folk song given a Hot 97 makeover.\n3. Knightley brings emotional confusion into sharp, touching focus as the pair bond over nighttime strolls and the shared love of classic pop. Carney flips film clichés to keep us guessing.\n4. Carney deserves great credit for the movie's clever, layered structure, and for resisting a few obvious plot turns along the way. Lightning doesn't strike, but sunshine works, too.\n5. Carney is wonderful at evoking the joy of music, its healing powers, and the way it connects us to each other and the universe.\n6. Can a song save your life? And can an impossibly mushy, mawkish movie make you feel good? The answer to the second question, anyway, is yes.\n7. \"Begin Again\" may not always swing, but it makes up for that in sincerity and a welcome willingness to ambush expectations.\n8. A movie that mocks artistic sellouts and then sells out itself is like a cop who chases a criminal before accepting a bribe to let him go.\n9. Another nice musical spin by Carney on the conventions of quasi-romantic dramedy.\n10. \"Begin Again\" is an ode to the indie-rock spirit, but its music -- written by Carney and others -- is as safe, smooth and Starbucks-friendly as Norah Jones.\nWrite a one sentence summary of the reviews above.\n\nAssistant: \"Begin Again\" is a movie that explores the joy and healing power of music, with a touching story and heartfelt performances, despite some predictability and occasional lapses in its indie-rock spirit."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Answer the following question: Question: Who did Franklin involves, without his permission? If there is no answer, please output \"Insufficient information to provide an answer.\". Movie title: Ben and Me Context: The film begins with a tour group in the present day, but quickly moves to the 16th century. Quickly tracing a mouse family's history in England, we see how the family was forced to go to America due to an excess of cats.In America, the film then focuses on Amos, a 1740s mouse living in a Philadelphia church. His attempts to land a job lead him to Benjamin Franklin's printing shop. He helps Franklin erase his debts, create a successful newspaper, invent new devices, and even helps make him more popular and respected with the people of the city.After Franklin involves Amos, without his permission, in the famous kite electricity experiment that nearly kills him, Amos is outraged and leaves him. Franklin's attempts to work alone prove disastrous such as a failed diplomatic mission to the King that pushes the colonies to open armed revolt.They finally reconcile in 1776 when Franklin begs Amos to return and agrees to sign a contract drafted by the mouse for more considerate treatment. Furthermore, as Franklin examines the agreement, Thomas Jefferson arrives to ask for help for the beginning of the Declaration of Independence he is writing. Amos insists that Franklin read his own agreement out loud at that moment and its wording inspires Jefferson to create an appropriate opening to his own piece.The movie ends with a return to the tour group in the present.\nAnswer:\n\nAssistant: Franklin involved Amos, the mouse, without his permission in the famous kite electricity experiment that nearly kills him. This incident led to Amos being outraged and leaving Franklin."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Short movie review: scotland looks wonderful , the fans are often funny fanatics \nDid the critic thinking positively or negatively of the movie?\n\nAssistant: Based on the short movie review provided, it appears that the critic is thinking positively about the movie. The reviewer mentions two aspects they seem to appreciate: the beautiful portrayal of Scotland and the amusing fanatic fans featured in the film.\nThe description of Scotland as \"wonderful\" suggests that the visuals and cinematography effectively captured the beauty and essence of the country, thereby contributing to the overall positive experience for the viewer. This portrayal could have made the movie more enjoyable and appealing for anyone who appreciates stunning landscapes and picturesque scenery. It also implies that the filmmakers successfully managed to showcase the charms and allure of Scotland, turning it into a fascinating backdrop for the story.\nMeanwhile, the mention of \"funny fanatics\" indicates that the movie succeeds in delivering an element of humor through its portrayal of the fans. The critic highlights their amusement with the actions and dedication of these ardent supporters, which adds an entertaining and perhaps even relatable aspect to the film. By appreciating the humor and the portrayal of passionate fans, the reviewer is acknowledging that the movie provides a lighthearted and engaging experience for the audience.\nIn conclusion, the critic appears to be thinking positively about the movie, as they have pointed out two strong aspects of the film - the captivating visuals of Scotland and the amusing portrayal of fans. The way these elements are highlighted in the review demonstrates the critic's appreciation for the movie's ability to entertain and deliver an engaging cinematic experience."} +{"text": "User: Q: You are presented with the quiz \"What do amphibians easily absorb?\" But you don't know the answer, so you turn to your teacher to ask for hints. He says that \"amphibians have permeable skin that easily absorbs substances from the environment\" and \"chemicals can be very toxic to permeable, amphibian skin\". So, what's the best answer to the question?\nAvailable choices: [i] heat energy. [ii] water. [iii] Energy. [iv] Bodily water. [v] chemicals. [vi] body water. [vii] smog. [viii] oxygen.\nA:\n\nAssistant: [v] chemicals."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Build a movie plot around this: Who does the dagger belong to? Zeko\nThe answer to this question is:\n\nAssistant: Title: The Bloodline of Zeko\nGenre: Action, Mystery, Adventure\nPlot:\nIn the bustling city of Veloria, a new generation of a once-powerful crime family is seeking to rise to power. As members of the criminal underworld scramble to establish their dominance, a simple yet mysterious question echoes through the ranks: Who does the dagger belong to?\nThe film begins with our protagonist, Luna, an ambitious and clever journalist who inherited the same resolve from her mother - a prominent crime reporter that vanished without a trace several years ago. After receiving an anonymous tip about the Zeko criminal empire, she decides to pursue the story in search of her mother's whereabouts and delve into the dark world of organized crime.\nAs Luna delves deeper into her investigation and gradually gains the trust of the Zeko family members, she unexpectedly uncovers the legendary \"Blood Dagger of Zeko\" - a priceless artifact that once belonged to the family's powerful and ruthless patriarch, Zeko the Elder. Left in his will, the dagger was known to grant immense power and authority to whoever it ultimately belonged to, becoming the key to the family's fortune and the city's criminal empire.\nHowever, an internal power struggle within the family begins to escalate when the youngest and most ambitious sibling, Adrian Zeko, sets his eyes on the dagger. As the other siblings, Rina and Niko, frantically search for clues to secure their own claim, a race against time ensues to determine the heir to the Zeko empire.\nLuna becomes entwined in this dangerous game as she uncovers links between the dagger, the family, and her mother's disappearance. As the story unfolds, it is revealed that Luna's mother had discovered the truth about the dagger and had intended to expose the family's secrets and transgressions before her sudden disappearance. Unable to overlook the pervasive corruption any longer, Luna decides to take matters into her own hands to honor her mother's legacy and to bring down the Zeko empire.\nIn a thrilling conclusion, Luna confronts Adrian, who is on the verge of obtaining the dagger. Amidst the chaos, Luna reveals her findings, causing a shift in power dynamics. The siblings ultimately unite to dismantle the family's criminal activities and restore justice to Veloria.\nThe film ends with Luna, finally honoring her mother's memory as an esteemed crime reporter, taking ownership of the dagger and gifting it to a museum for the citizens of Veloria to witness the tale of the once-powerful Zeko family dynasty.\nThe answer to the question \"Who does the dagger belong to?\" is ultimately revealed to signify the triumph of truth, justice, and morality over a dark and corrupted past. Luna's brave pursuit for her mother's legacy results in the downfall of the Zeko family, and the dagger finally belongs to the people of Veloria in a symbolic gesture of hope and redemption."} +{"text": "User: Question: Article: There is a new park near Andy's home. It's fine today. Andy and his family are in the park now. On their left, there is a cafe. On their right, there is a big lake. There are many flowers and trees around the lake. There's a small hill behind the lake. Near the lake, there are two signs. One says, \"Don't swim in the lake!\" The other says, \"No fishing!\" But you can go boating on the lake. There is a beautiful garden in the middle of the park. There are green grass and beautiful flowers in it. There are some small shops between the lake and the garden. The park is so nice. Andy and his family like it very much. Question: Is there a cafe in the park? Yes or no, is the answer \"Yes, there is.\"?\nAnswer:\n\nAssistant: Yes, there is."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Context: Lee went and talked to a therapist and decided it was worth his time to keep going. Question: How would Lee feel afterwards? Which one of these answers best answers the question according to the context? A: desperate afterwards B: lonely afterwards C: hopeful afterwards\nA:\n\nAssistant: C: hopeful afterwards"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: What was the name of the album released in 2002 by the band that played at Lollapalooza in 1994? Answer the above question based on the context below: In 1993, the Breeders released their second album, Last Splash. At this time, the group's lineup consisted of sisters Kim and Kelley Deal on guitar and vocals, Josephine Wiggs on bass and vocals, and Jim Macpherson on drums. Last Splash was successful in various countries worldwide, and the group toured extensively and played at Lollapalooza 1994. In November 1994, Kelley Deal was arrested on drug-related charges, and in 1995, Wiggs left to pursue other musical projects. Macpherson continued playing with Kim Deal in her side-project group, the Amps, and then in the 1996 incarnation of the Breeders, but quit the band in 1997. The Breeders' lineups for their albums Title TK (2002) and Mountain Battles (2008) included the Deal sisters, Mando Lopez, and Jose Medeles. In 2013, Wiggs and Macpherson rejoined the Deals to tour the 20th anniversary of Last Splash—the LSXX Tour.On December 31, 2013, the Breeders performed their final concert on the 60-date tour in Austin, Texas. The group enjoyed the LSXX concerts, and decided they would like to record new music together. Throughout 2014, Wiggs traveled from her home in Brooklyn, New York to Dayton, Ohio, near where Macpherson and both Deals lived. The group began practicing new material in Kim Deal's basement, including compositions by her and one by Wiggs. By August, there were three new songs they could play well, two less so, and others they had not yet practiced. Reported titles were \"Skinhead #2\", \"Simone\", \"All Nerve\", and \"Launched\". The band Neutral Milk Hotel asked the Breeders to open for them at a Hollywood Bowl concert to be held on September 18. The latter decided to go on tour leading up to this show and to perform some new compositions in preparation for their eventual recording.\nThe answer to this question is:\n\nAssistant: The album released in 2002 by the Breeders, the band that played at Lollapalooza in 1994, was \"Title TK\". \nThis information was identified by carefully reading the provided context and extracting key pieces of information. The context mentions that the Breeders played at Lollapalooza in 1994. It also separately points out that one of their albums, \"Title TK\", was released in 2002. Thus, aligning these facts together answers the question accurately."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Select from the following. -- Wear gloves and safety goggles to protect your eyes and mouth while attempting to remove the headlight unit. Be careful and gentle to avoid damaging the surrounding oil can.; -- Scrape any dirt off the entire bumper, including plastic parts and upholstery. Pull down the dashboard as much as possible to expose any exposed headlight casing.; -- Again, there is no need to completely remove the dome light, as the headlight unit doesn't pop away when dropped. Instead, you have to completely break the bumper and remove the lens.; -- Separate the wiring connector from the beam unit, and disconnect the wires from the light bulb socket. Hold the lamp unit with your thumb to control the spring, and remove the lens-retaining spring.;Choose from options above and answer: What most naturally follows?\nHow to remove beetle headlights \nDisconnect the negative battery cable to eliminate the risk of electrical shock. \nSometimes battery cable colors vary, but normally you will find that the red battery cable is positive (+) and the black battery cable will be negative (-). \nUse a screwdriver to loosen the screw that retains the rim, and then pull out the entire headlight unit.\nAnswer:\n\nAssistant: Alright, little buddy, let me explain it to you: To take off the beetle headlights, first, you need to be very careful and not touch the wires because they can give you a little shock. The beetle has two wires, one is red, and the other one is black. You should ask a grown-up to help you take off the black wire. This will make the car safe to work on!\nNext, you'll need a tool called a screwdriver to remove a small round thing called the rim around the headlight. Once that's done, you can carefully pull out the whole headlight like a puzzle piece. And that's how you remove the beetle headlights!"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: What is the sentiment of the following tweet?\nTweet:i see Miley on the red carpet [online] she is awesome!!!\n\nAssistant: To analyze the sentiment of the tweet, let's break it down step-by-step:\nStep 1: Identify the main components of the tweet\n- Subject: Miley (as in Miley Cyrus, a celebrity)\n- Context: on the red carpet\n- User's opinion: \"she is awesome!!!\"\nStep 2: Determine the sentiment associated with each component\n- For the subject and context, there's no inherent sentiment, as they're neutral pieces of information.\n- The user's opinion, however, is very positive, indicated by the words \"awesome!!!\" and the enthusiastic exclamation marks.\nStep 3: Combine the sentiment values and reach an overall sentiment\nSince the user's opinion is the main sentiment-carrying element of the tweet, and it is quite positive, we can determine that the overall sentiment of the tweet is positive.\nConclusion: The sentiment of the tweet about Miley Cyrus on the red carpet is positive, as the user expresses enthusiasm and admiration for her."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Sentence: name = The Vaults, eatType = pub, priceRange = moderate, customer rating = 1 out of 5, near = Café Adriatic\nCan you represent the content in this sentence in data form?\n\nAssistant: Certainly, the content in the given sentence can be represented in a structured data form, such as a dictionary or JSON format. Here's a representation of the content in JSON format:\n```JSON\n{\n \"name\": \"The Vaults\",\n \"eatType\": \"pub\",\n \"priceRange\": \"moderate\",\n \"customerRating\": {\n \"rating\": 1,\n \"maxRating\": 5\n },\n \"near\": \"Café Adriatic\"\n}\n```\nThis JSON object represents the data as key-value pairs, where information about the given establishment, The Vaults, is organized in a structured manner. The object includes details on the eatType (pub), priceRange (moderate), customer rating (1 out of 5), and the nearby location (Café Adriatic). The customer rating is represented as a nested object, with a rating and maxRating for more clarity. This structured data format makes it easier to process, analyze, and manipulate the information."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Do these questions have the same meaning?\nWhat should be my research rationale to do my research abroad? I'm doing my project on cloud computing. What should I write and how should I write it?\nI want to do a project in cloud computing. What is an easy and feasible topic for my project?\nSelect from: 1. no; 2. yes;\n\nAssistant: 1. no\nThe two questions have different meanings. The first question is about writing a research rationale for conducting research abroad related to a cloud computing project. It asks for advice on what to include in the rationale and how to write it.\nThe second question is looking for an easy and feasible topic for a cloud computing project, without any reference to conducting research abroad or writing a research rationale."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Are these paraphrases?\nSaghar has written over 2,000 songs for many singers and music directors for Pakistani films , radio and TV .\nFor many singers and music directors , Saghar has written more than 2,000 songs for Pakistani films , radio and TV .\n\nAssistant: Yes, these two sentences are paraphrases of each other. Here is the justification:\n1. Both sentences convey the same primary information: Saghar has written over 2,000 songs for various artists and platforms, specifically for Pakistani films, radio, and TV.\n2. The sentences are structured differently, but they are semantically equivalent. The first sentence begins with Saghar as the subject and emphasizes the number of songs, while the second sentence begins with the intended audience of the songs (singers and music directors) and still mentions the impressive volume of Saghar's work. The different structures make them paraphrases of each other, not identical sentences."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Giovanni Marks (born October 2 1978) better known by his stage name Subtitle is a rapper and producer based in Los Angeles California. He is one half of the duo Lab Waste alongside Thavius Beck. He has also collaborated with other artists such as K-the-I??? Busdriver and Islands. Given a list of categories: company, educational institution, artist, athlete, office holder, mean of transportation, building, natural place, village, animal, plant, album, film or written work, what category does the paragraph belong to?\nThe answer to this question is:\n\nAssistant: To determine the appropriate category for the paragraph, I will analyze its content step-by-step.\n1. The paragraph introduces a person named Giovanni Marks, who has a stage name, Subtitle. This suggests that the person might be involved in the entertainment industry or arts.\n2. It mentions that Subtitle is a rapper and producer based in Los Angeles, California. This confirms that the person belongs to the music industry, thus aligning with the artist category.\n3. The paragraph cites his role in the duo Lab Waste and collaborations with other artists. These details further support the involvement of Subtitle in the music industry as an artist.\nAfter analyzing the paragraph, I can confidently say that it belongs to the \"artist\" category. So, the answer to this question is: artist."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Given the question: Last week Johnny had a very busy weekend. On Saturday morning, he cleaned his room. That made his mother very happy. In the afternoon, he did his math homework. It was not difficult, so it only took him one hour to finish the homework. And in the evening, he went to visit his aunt with his parents. They had a big dinner there. The next morning, he went swimming after getting up. Usually he goes swimming twice a week. He loves it very much. After lunch he went to school playground and played football. That exercise makes him in a good health. On Sunday evening, he watched TV for an hour at home and then put the books in his schoolbag for the next day. What did Johnny do after lunch on Sunday? A) He played basketball. B) He played football. C) He went swimming. D) He had a rest.\nThe answer is:\n\nAssistant: B) He played football.\nLast week Johnny experienced a busy weekend, engaging in various activities in his available time. On Saturday, he tackled cleaning his room in the morning—a task that pleased his mother. Later on, in the afternoon, Johnny focused on completing his math homework. Fortunately, he found the work relatively easy, and as such, it only required an hour of his attention.\nIn the evening, Johnny accompanied his parents on a trip to visit his aunt. While there, they all enjoyed a large dinner together. On Sunday, Johnny began his day with swimming, a sport he is particularly fond of, and typically participates in twice a week. Afterward, he returned home for lunch.\nPost-lunch, Johnny took advantage of the school playground, engaging in a game of football. Engaging in such exercise not only promotes his overall good health but also aligns with his recreational interests. Finally, on Sunday evening, Johnny spent an hour watching television before he proceeded to organize and pack his school books into his schoolbag in preparation for the following day.\nGiven the presented options, it is clear that Johnny engaged in football after enjoying lunch on Sunday."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Summarize this article:\nPhoto: AP \n \n Donald Trump has announced that New Jersey Governor Chris Christie will serve as chairman of his transition team, identifying personnel who would become part of a Trump White House. Christie will be working closely with Trump’s son-in-law Jared Kushner, whose father he sent to prison. \n \n Just months ago, Trump accused Christie of having committed a federal crime in covering up the Bridgegate scandal. Now, NBC News reports, Christie will bring “substance” and “seriousness” to the Trump campaign: \n \n Sources familiar with the planning say the transition team appointment shows “the role the governor is playing as the serious person around Trump who is trying to provide some discipline” to his growing organization. \n \n Christie also continues to act as “a conduit between the establishment and Trump” based on Christie’s relationships with Republican elected officials across the country, sources said. \n \n The New York Times reported last week that Kushner—Trump’s son-in-law and a real estate heir in his own right—would be a part of the very forward-looking transition team. (President Obama didn’t announce his Transition Team Chariman until the day after he won the 2008 election.) In 2005, as U.S. Attorney for the state of New Jersey, Christie successfully prosecuted Kushner’s father, Charles, for “crimes of greed, power, and excess.” \n \n There does not seem to be any bad blood between Christie and Jared Kushner, however: In November, New Jersey awarded $93 million in tax breaks to a Kushner Companies high-rise in Jersey City. \n \n Asked why the Trump campaign was directing press inquiries about the transition committee to the governor’s office, a spokeswoman for Christie told Gawker, “We’re trying to figure that out ourselves.” ||||| If Christie is truly heading Trump’s day-to-day transition effort, the answer to both questions has to be no, said Max Stier, who is chairman of the Partnership for Public Service and has been advising the campaigns on transition planning. “If Christie is going to really run the transition process, he can’t effectively be active in the campaign or be himself a candidate because it’s a full-time job,” Stier said. In 2012, Mitt Romney tapped former Utah Governor Mike Leavitt to run his transition, and Leavitt kept his distance from the campaign while building out what he called a 600-person “government in miniature.” It’s not clear if Trump plans a similar operation. If Christie’s role atop the transition team was more ceremonial, it might be possible for him to split time with the campaign, to stay in contention to be Trump’s running mate—or even to do the job he was elected to perform, governing the state of New Jersey. \n \n The announcement came just a few days after The New York Times reported that Trump had asked his son-in-law, Jared Kushner, to head up the transition efforts. Kushner was not mentioned in the press release announcing Christie’s new job. If Monday was an early indication, Christie seems to be operating pretty independently from the campaign: A Trump spokeswoman directed questions about his role to Christie’s gubernatorial office, which promptly referred those queries back to Trump’s campaign. \n \n Hillary Clinton, the likely Democratic nominee, has not publicly announced a transition chief. And the fact that Trump made his own announcement so quickly is at least a momentary relief to those who wondered how seriously he would take the process. “First and foremost, it is really important that Trump has publicly named a transition chairman,” Stier said. “Trump is demonstrating his understanding that he needs to prepare for the possibility that he will be governing at the same time as he is trying to win the campaign.” \n \n We want to hear what you think about this article. Submit a letter to the editor or write to letters@theatlantic.com. |||||\n\nAssistant: Donald Trump has selected New Jersey Governor Chris Christie to serve as chairman of his transition team, which means he will be tasked with identifying personnel to become part of a Trump White House. This appointment is seen as bringing \"substance\" and \"seriousness\" to Trump's campaign. The role demonstrates that Christie is seen as the \"serious person\" around Trump who is trying to provide some discipline to his growing organization. Christie will be working with Trump's son-in-law Jared Kushner, despite the fact that Christie sent Kushner's father to prison in 2005. The appointment comes just a few days after it was reported that Trump had asked Kushner to head up the transition efforts, and demonstrates Trump's belief that he must prepare for the possibility he will be governing at the same time as he is trying to win the campaign."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Please answer the following question: Melody Barnes - Melody C. Barnes (b. April 29 1964 Richmond Virginia) was chosen by President Barack Obama to serve as the Director of the Domestic Policy Council for his administration. Obama announced her appointment on November 24 2008. She was previously Executive Vice President for Policy at the Center for American Progress and led the CAP's Faith and Progressive Policy Initiative. She left the CAP in June 2008 to work on the Obama campaign. Given a choice of categories company, educational institution, artist, athlete, office holder, mean of transportation, building, natural place, village, animal, plant, album, film or written work, the text refers to which one?\nAnswer:\n\nAssistant: The text refers to an \"office holder,\" as Melody C. Barnes is described as having served in various roles, such as the Director of the Domestic Policy Council for President Barack Obama's administration as well as Executive Vice President for Policy at the Center for American Progress."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: There's always next year \n \n 21 Mar 14 \n \n After 25 games, the billion dollar bracket dream ends \n \n Just 25 games into the NCAA tournament, the billion dollar bracket bubble has already burst. \n \n Going into the third wave of games on the second day of the NCAA tournament, only three perfect brackets remained in the Quicken Loans billion dollar bracket contest. And all three had George Washington beating Memphis. \n \n As Maurice Creek’s game-tying three came up short, so went the retirement plans of the three remaining hopefuls. \n \n Memphis win means that zero perfect brackets remain in the #BillionDollarBracket – Warren Buffett and @cavsdan are celebrating — Daryl Ruiter (@RuiterWrongFAN) March 22, 2014 \n \n Billionaire Warren Buffett couldn’t have picked a better year to sponsor such a contest. Five double-digit seeded teams have won games with four games remaining in the second round. \n \n Even without a particularly mad March, the odds of picking all 63 games correctly are long. The chances of doing so was tabulated by a DePaul University math professor to be one-in-9.2 quintillion. \n \n The prospect of someone taking home the billion dollar prize took a big hit after the very first game of the tournament. About 84 percent of contestants incorrectly picked Ohio State to beat Dayton in the opener. \n \n The biggest upset, by far, was Mercer’s win over Duke. Over 97 percent of contestants picked the Blue Devils. \n \n While the billion dollar dream is over, the top-20 brackets will receive $100,000. ||||| Mercer forward Jakob Gollon (20) and guard Langston Hall (21) celebrate after defeating Duke 78-71 in an NCAA college basketball second-round game against Duke, Friday, March 21, 2014, in Raleigh, N.C.(AP... (Associated Press) \n \n Stanford head coach Johnny Dawkins yells to his players during the second half of a second-round game in the NCAA college basketball tournament, Friday, March 21, 2014, in St. Louis. Stanford won the... (Associated Press) \n \n Memphis guard Chris Crawford (3) reacts to a Mercer goal against George Washington during the second half of an NCAA college basketball second-round tournament game, Friday, March 21, 2014, in Raleigh.... (Associated Press) \n \n Gonzaga guard David Stockton pulls in a loose ball during the second half in a second-round game against Oklahoma State in the NCAA men's college basketball tournament Friday, March 21, 2014, in San Diego.... (Associated Press) \n \n A second day of upsets ended any chance of someone having a perfect NCAA tournament bracket in Warren Buffet's $1 billion challenge. It was a favorite that provided the first blemish on the final three people's brackets in the Quicken Loans contest on the Yahoo Sports website. \n \n All three had ninth-seeded George Washington beating Memphis. The Tigers won 71-66. \n \n \"If Warren Buffett wants to donate the (billion) to our university, we will take it and use it in good company,\" Memphis coach Josh Pastner said. \"We'll find a way.\" \n \n It only took 25 games for everyone to be eliminated. Then again most of brackets were knocked out on the tournament's first full day. The number of unblemished brackets kept dwindling after third-seed Duke, sixth-seed UMass and seventh-seed New Mexico lost Friday. \n \n Only 16 people remained perfect after 10th-seeded Stanford topped New Mexico. Then Tennessee routed UMass, leaving only six people with a chance of beating the 9.2 quintillion-to-1 odds. \n \n Gonzaga's victory over Oklahoma State cut that down to the final three. \n \n Even though no one won the $1 billion, the top 20 scores will still each get $100,000. \n \n Quicken Loans, which is sponsoring and insuring the Buffet contest, said on its Twitter feed that it wouldn't reveal the number of entrants to the challenge. The pool was supposed to be capped at 15 million entries. It probably wouldn't have mattered if they had let more people join. \n \n At CBSSports.com, only 0.03 percent of entrants were still perfect after Mercer upset Duke. They didn't last much longer as Tennessee's rout of UMass wiped out all the remaining unblemished entries. It took 21 games to end everyone's hope of perfection this year. Last season it took 23 games and 24 in 2012. \n \n A year ago, not a single person of the 11 million who entered on ESPN's website was perfect after a first day filled with upsets. Just four got 15 out of 16 right. \n \n This year people lasted a little longer. After 28 games, all 11 million entries had at least one mistake. \n \n ___ \n \n Follow Doug on Twitter at http://www.twitter.com/dougfeinberg |||||\nWrite a summary.\n\nAssistant: In 2014, the billion-dollar bracket dream for the NCAA tournament ended just 25 games into the competition. Only three perfect brackets remained before the third wave of games, and all three had George Washington beating Memphis. Unfortunately, as Maurice Creek's game-tying three missed, so did the billion-dollar hopes. Despite no one winning the ultimate prize, the top 20 scores received $100,000. The odds of correctly picking 63 games were calculated to be one-in-9.2 quintillion by a DePaul University math professor."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Article: Dear All, Happy New Year! A list of important dates for this term is attached to this letter. Interviews for our April and September Primary One classes will take place on Tuesday and Wednesday,18th&19th January.Now is the time to apply for admission to Peak School. Children are expected to wear their uniforms this term.As is usual at this time of the year,mornings at the school are quite cold and l would suggest children wear their sweaters to school.As the temperature goes up during the day,children tend to take off their outer clothes,so it is important that all sweatshirts are labeled . I have received a few requests from parents asking if their children can wear summer uniforms this term.I am of the view that all children in the school should follow the school policy and be dressed in the appropriate uniforms.In the longer term,a number of US at school have questioned the need to have a winter and summer uniform.Perhaps.with careful consideration,we can develop a uniform that will serve children for the whole academic year. As the foggy mornings start to appear,a number of parents have raised concerns over the safety of those children who walk to school. Our dark green uniforms are not particularly visible to motorists.Hopefully, the \"Uniform Committee\" to be set up this term will deal with this particular issue.I will keep parents informed of the channels through which views may be expressed in future newsletters. We still have a \"No Hat,No Play\" rule on the playground,so children need to ensure that they have their sunhats in school at all times. Yours sincerely, P D Tunesi Principal Question: What do we know about the present school uniforms? Yes or no, is the answer \"The school is considering changing them.\"?\n\nAssistant: From the article, we know that the present school uniforms include separate winter and summer uniforms, and that the winter ones are dark green in color. Some parents have questioned the need for both winter and summer uniforms and have suggested developing a uniform that can be worn the entire academic year. Additionally, there are concerns about the visibility of the dark green uniforms for children walking to school in foggy mornings.\nThe answer is \"Yes, the school is considering changing them,\" as the principal mentions that a \"Uniform Committee\" will be set up to deal with the issue of visibility, and that there is a possibility of developing a uniform to serve for the whole academic year after careful consideration."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Answer the following question: Title: another g-shock for me Review: This is probably my third G-shock over 25 years. A good worker's watch. I also think it is a nice looking \"rounded appearance\" every day watch. Years ago they were square. This is the only watch I wear. My first two watches never stoped running, just the rubber/plastic casing gets broken up over the years from contact with varius items I'm working on. The battery lasts 4-5 years if you dont use the light feature button much. The bands also need replacing over the years with heavy use. The three alarms I never use (high tones) because of my hearing loss. Although they claim water resistant I don't take it in the water. One did leak on me once but after it dried out it continued to work. Less than $50.00 including tax and shipping. Does this product review convey a negative or positive sentiment?\nAnswer:\n\nAssistant: The product review conveys a predominantly positive sentiment towards the G-shock watch. The reviewer shares their long-term experience with the product, stating that this is their third G-shock over a period of 25 years, which implies a sense of loyalty and satisfaction with the brand. Additionally, the reviewer mentions the watch's durability, as the first two watches never stopped running and the battery life can last 4-5 years if the light feature button is not used frequently. The watch's design and rounded appearance are also considered as appealing for an everyday watch.\nHowever, there are a few negative remarks mentioned as well. The reviewer notes that the rubber/plastic casing tends to get damaged, and the band needs to be replaced after years of heavy use. Furthermore, they reveal their disappointment in the watch's alarm feature, as the high tones cannot be heard due to their hearing loss. Lastly, the reviewer shares a concern related to water resistance, recounting an incident when the watch leaked, but acknowledges it continued working after drying out.\nOverall, the sentiment of the review is positive, as the user appears to be content with the watch's performance and design while acknowledging a few flaws. The fact that the reviewer continues to purchase G-shock watches after 25 years is a strong indication of their satisfaction with the product."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Write some highlights for the following article:\nManned by a pair of trained officers, a large private security jeep moves slowly past the high walls lining a residential street in the suburbs of Johannesburg, South Africa. Heavily armed and alert, the officers scan the designated area, looking for any suspicious activity. They work for CSS Tactical, a private security company hired by local residents to patrol their area and protect them from any potential criminals. \"To be proactive, we need to be noticed,\" says Ricky Croock, managing director of CSS Tactical. \"We need people to say 'hold on, there is a dedicated tactical vehicle.' \"A tactical vehicle is a bigger vehicle -- it is more aggressive, it's a deterrent. So the first point is we want to deter as opposed to catch crime.\" The security officers are well armed. They carry semi-automatic weapons and hand guns and they're trained in how and when to use them. And while only the police have the power of random stop and search in public places, security companies also stop people they regard as suspicious. From April 2011 to March 2012 there were 15,609 murders in South Africa, as well as 64,514 sexual offenses and 101,203 cases of aggravated robbery. According to the country's latest \"Victims of crime survey,\" 57% of respondents felt that housebreaking/burglary was the crime most feared in their area of residence. In recent years, South Africa has stepped up efforts to tackle crime, one of the country's worst social ills. Last September, the South African Police Service said that serious crime in the country had been reduced by 31.8% from 2004/5 to 2011/12. But 2011 to 2012 saw a reduction in serious crime of just 1.9%. \"We only have the official crime statistics to go by,\" says Rudolph Zinn of the Institute for Security Studies in South Africa. \"If you look at what the police would report annually in September of every year, we see over the last three or four years that crime has leveled off. \"The more concerning factor is the high percentage of so-called violent or contact crime, which is about a third of all the crime reported to the South African Police Service.\" And that's what has fueled the boom in the country's private security industry, crime experts say. There are nearly 9,000 companies and 400,000 registered active private security guards. That's more than the police and army combined, according to South African officials. \"The security industry is bigger than what it has ever been in South Africa,\" says Zinn. \"I think the growth in the industry is definitely attributed to the fact that, let's call it a weak policing or ineffective policing, and it created the opportunity for private individuals to move into the market.\" Some of those individuals have a military background. Waal de Waal, chief operating officer at South African-owned Protea Coin, spent 18 years in the special forces and military intelligence before joining the private security sector a decade ago. He is today one of the 17,000 people employed by Protea Coin. Part of the multi-million dollar company's work is dealing with small-scale crime, like recovering fuel stolen from a petrol station. The company, which also provides security at many of South Africa's mines and has been kept busy by the recent strikes in the sector, offers a range of services, including 24-hour property monitoring and protection against organized criminal syndicates. \"The role of private security companies is distinctively different to that of the police,\" says Petrus Van Niekerk, the chief executive of Protea Coin. \"We view our role as private security to really aid and support the police.\" But what happens when private security companies take matters into their own hands? The South African government claims that it's an industry that threatens national security and is determined to tighten its regulation. \"A very dangerous situation arises if you have a security industry that outstrips both your police and your army and there is completely no regulation of that industry,\" says Jenni Irish-Qhobosheane, the head of the Civilian Secretariat for Police. \"I think the public can find themselves in situations where, basically, they are put at the mercy of private security companies who are unprofessional,\" she adds. One measure considered by the country's government is that private security firms must be owned by South Africans. \"This type of move is going to frighten away foreign direct investment and I think that is a danger,\" says Hanes Venter, sales director at the South African arm of G4S, a London-listed global security giant. \"Foreign direct investment is already down year on year but this is a worry from that perspective for me.\" Whatever regulations the South African government eventually introduces, it seems unlikely the demand for private securities companies will diminish amid high unemployment rates, as well as widespread poverty and simmering industrial tensions.\n\nAssistant: A growing private security industry in South Africa has emerged due to high rates of violent crimes and perceived inefficiency of local police. CSS Tactical, a private security company, is one example of a firm now hired by Johannesburg residents to patrol their neighborhoods as they aim to deter crime. Recent statistics showed that in 2011 to 2012, South Africa experienced 15,609 murders, 64,514 sexual offenses, and 101,203 cases of aggravated robbery. The escalation of crime in the nation has driven the demand for private security companies.\nSouth Africa currently has nearly 9,000 security companies, with 400,000 registered active private security guards in total. This number is larger than the combined number of the police and army, according to South African officials. The government claims that this expanding industry presents a threat to national security and intends to tighten its regulation. One proposed measure is that private security firms must be owned by South Africans to limit potential manipulation by foreign entities. Although crime rates have seen a reduction in recent years, it is unlikely that the demand for private security firms will decrease due to high unemployment and poverty rates, as well as ongoing industrial tensions."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Please answer the following question: Combine facts and answer this: Which 2006 British-American mockumentary comedy film was Bad Grandpa a reminiscent?\nAnswer:\n\nAssistant: The 2006 British-American mockumentary comedy film that \"Bad Grandpa\" is reminiscent of is \"Borat: Cultural Learnings of America for Make Benefit Glorious Nation of Kazakhstan.\" Directed by Larry Charles and starring Sacha Baron Cohen as the titular character, Borat Sagdiyev, this film follows the journey of a Kazakh journalist as he travels across the United States to learn about American culture, customs, and lifestyle. Like \"Bad Grandpa,\" \"Borat\" also combines scripted scenes with unscripted interactions with real people, resulting in shocking, hilarious, and cringe-worthy moments that reveal the prejudices and peculiarities of the everyday people Borat encounters."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Translate to Russian:\nAt the left top there is the current set of tiles used.\n\nAssistant: На верху слева находится текущий набор используемых плиток. Если объяснять как для пятилетнего: представь, что это картинка с разноцветными квадратиками сверху слева. Эти квадратики - это наши плитки, которые мы сейчас используем."} +{"text": "System: You are an AI assistant, who knows every language and how to translate one language to another. Given a task, you explain in simple steps what the task is asking, any guidelines that it provides. You solve the task and show how you used the guidelines to solve the task.\n\nUser: In this task, you're given a statement, the genre to which that statement belongs, and a label indicating if the statement should be agreed with (entailment), disagreed with (contradiction), or neither (neutral). Your job is to write a sentence that describes the genre that follows the tone with respect to the statement, as indicated by the label. If sentence X agrees with sentence Y, the can be concluded from one another. If sentence X disagrees with sentence Y, they can not be correct at the same time. The sentence must also belong to the genre specified.\nQ: Statement: He'd been unconscious long enough for them to have gathered any amount they wanted.\nLabel: entailment.\nGenre: fiction.\nA:\n\nAssistant: As a result of his prolonged unconsciousness, they had plenty of time to collect whatever they desired."} +{"text": "User: Jesse had to get everyone's coffee order. Jesse left that coffee shop with the orders. Given that the answer to a question is \"know what everyone wants\", what is the question?\n\nAssistant: What does Jesse need to do in order to successfully complete his task at the coffee shop?"} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Question: what does the family leave act? Would \"In order to be eligible for FMLA leave, an employee must have been at the business at least 12 months, and worked at least 1,250 hours over the past 12 months, and work at a location where the company employs 50 or more employees within 75 miles.\" be a reasonable answer?\n\nAssistant: The Family and Medical Leave Act (FMLA) is a United States labor law that provides certain employees with up to 12 weeks of unpaid, job-protected leave per year for specific family and medical reasons. It was signed into law on February 5, 1993. The purpose of the FMLA is to offer employees job protection and a balance between work and personal life during certain life events that require time away from work.\nYour answer is correct in detailing the eligibility requirements for an employee to qualify for FMLA. Here is a rephrased version:\n\"In order to be eligible for FMLA leave, an employee must have worked for the company for at least 12 months, clocking a minimum of 1,250 hours within that time frame. Additionally, the employee must work at a company location with 50 or more employees within a 75-mile radius.\"\nThis accurately summarizes the main criteria an employee needs to meet to be eligible for FMLA leave."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: On what date was the 2016 studio album which had the song \"Freedom\" released?\nA:\n\nAssistant: Step 1: Identify the album containing the song \"Freedom.\"\nStep 2: Find out who the artist is for the song, and the album.\nStep 3: Search for the release date of the album.\nThe song \"Freedom\" is from the album \"Lemonade\" by Beyoncé. \nNow that we have identified the album and artist, I will find the release date for the \"Lemonade\" album.\n\"Lemonade\" by Beyoncé was released on April 23, 2016.\nSo the answer is A: April 23, 2016."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: What most naturally follows?\nHow to rush a sorority \nUnderstand the greek system. \nThe major female greek system is governed by the national panhellenic conference (npc). The npc is made up of 26 member organizations.\n\nAssistant: So, after learning about the Greek system and the National Panhellenic Conference (NPC), which has 26 member organizations, the next step is to find out which sororities are on your college campus. Then, you can attend their events and get to know the members to see if you fit in. Once you find a sorority you like, you will participate in the rushing and recruitment process, which is like a big party where the sororities and potential new members meet each other. By doing all these steps, you have a better chance of joining the sorority of your choice. Remember, it's like finding a group of friends who share the same interest as you!"} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: What is the answer to this question? who played alfred in miracle on 34th street?\n\nAssistant: The actor who played Alfred in Miracle on 34th Street is Alvin Greenman. He acted in the movie 1947 version. Alfred is a nice man who works at the same store as Santa Claus and helps him bring happiness to the people during the Christmas season."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Question: Question: what state is jakarta in? I found the following answer on Google: The city is the seat of the ASEAN Secretariat. Is that a correct answer? Yes or no.\nAnswer:\n\nAssistant: No. Jakarta is the capital city of Indonesia and is located in the Special Capital Region of Jakarta, not a state."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Data: United States, ethnicGroup, Asian Americans; United States, language, English language; A Severed Wasp, country, United States\nSentence about the following data:\n\nAssistant: In the United States, there are many Asian Americans, which is a group of people from different Asian countries. Almost everyone in the United States speaks the English language, and one story from this country is called \"A Severed Wasp.\""} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Summarize this article in one sentence.\nIf you’re new to doing a 180 on your skateboard or haven’t done the trick in a while it’s a good idea to get your body used to being in the proper position. Practice skating around for a bit by pushing off and then situating your body into the right position. The ollie is the most fundamental skateboard trick and it’s incorporated into pretty much every trick you’ll do. And the 180 is an ollie where you turn your board and body 180-degrees. For the 180 you’ll be turning around either behind you for a backside 180, or in front of you for a frontside 180. Once you have the ollie down, you’re well on your way to doing a 180 skateboard trick! If you ride “regular” you will have your right foot in the back positioned on the tail of the board. If you ride “goofy” then your left foot will be in back on the tail. You want to position your back foot on the tip of the tail. To do a 180 on the skateboard, you’ll want to position your back foot on the corner of the tail. For a straight ollie your front foot will be centered about halfway up the board or slightly back. You also want to put more pressure on the outer edge of your foot (the side your pinky toe is on). To help you with turning, plant your front foot so its slightly hanging over the opposite edge of board from your back foot. This may feel awkward but it will help to keep your board balanced when you’re in the air. Your shoulders and your upper body is where most of the work will come from when you perform your 180. By setting your shoulders correctly you will be able to turn your body quickly enough and this should allow your legs and the board to follow. You might be tempted to try and force the board to turn using mostly your legs, but you won’t get enough torque this way. For a backside 180, position your front shoulder slightly in front of your chest and dip it down so when you pull up you have more momentum to work with. For a frontside 180, you want to position back shoulder in front of your chest and dip it down as well so you can pull up and create momentum. You don’t want too much of a bend so that you end up crouching too low as that won’t allow you to pop high enough. You should stand with relaxed knees like you’re getting ready to hop slightly. If you are bending correctly you should be able to slide your front forward up the board with ease.\n\nAssistant: This article explains the process of performing a 180 skateboard trick, emphasizing the importance of mastering the ollie, proper foot positioning, and using the upper body for turning momentum."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Briefly summarize this sentence: new zealand army engineers rebuilding iraq have been confined to their base by escalating violence in the country and prime minister helen clark said tuesday they may leave if the unrest continues .\n\nAssistant: The New Zealand Army engineers who have been working on rebuilding efforts in Iraq are currently confined to their base due to a significant increase in violence within the country. Prime Minister Helen Clark mentioned on Tuesday that if this ongoing unrest persists, it may result in the withdrawal of these engineers from Iraq."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: You will be given a definition of a task first, then some input of the task.\nIn this task, you're given a four sentences of story written in natural language. Your job is to complete end part of the story by predicting appropriate last sentence which is coherent with the given sentences.\nSentence1: Regrettably, our kitty passed away two weeks ago. Sentence2: She was a young cat, and we didn't know why she died. Sentence3: We found her on the floor when we woke up in the morning. Sentence4: My husband buried her in the backyard.\nOutput:\n\nAssistant: We planted a small tree next to her grave to keep her memory alive."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Read the text and determine if the sentence is true:\nSt Estevam is an Island in Ilhas, Goa, India. St Estevam is known also 'Juvem' and in the past was known as 'Shakecho Juvo' - the isle of vegetables - known for its long, seven-ridged, light green ladyfingers. Therefore the people of Juvem came to be nicknamed 'bhende'.\nSentence: St Estevam contains many giraffes on the island.\n\nAssistant: Step 1: Read the text carefully\nThe given text describes St Estevam, an island in Ilhas, Goa, India. It mentions its alternate names as \"Juvem\" and \"Shakecho Juvo,\" which means the isle of vegetables. The text also mentions that St Estevam is known for its long, seven-ridged, light green ladyfingers and that the people of Juvem are nicknamed 'bhende'.\nStep 2: Analyze the given sentence\nThe given sentence states that St Estevam contains many giraffes on the island.\nStep 3: Compare the statement to the information provided in the text\nThe text describes the location, alternate names, the type of vegetation, and a nickname for the residents of St Estevam. However, there is no mention of the presence of giraffes on the island.\nStep 4: Determine if the sentence is true\nSince the text does not contain any information about the presence of giraffes on St Estevam, we cannot determine the given sentence to be true based on the information provided."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Attributes: name = The Phoenix, food = Chinese, customer rating = low, area = riverside. Produce a detailed sentence about this restaurant.\n\nAssistant: The Phoenix, a riverside Chinese restaurant offering a diverse selection of authentic cuisine, finds itself struggling to improve its low customer rating despite its idyllic location, and seeks to enhance the dining experience for its patrons to gain better recognition within the community."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Data: Bandeja paisa, ingredient, Lemon; Lemon, family, Rutaceae\nSentence about the following data:\n\nAssistant: Bandeja Paisa is a yummy dish that uses a lot of ingredients, and one of them can be lemon. Lemons are a type of fruit that comes from a family of plants called Rutaceae. It's like how your family has a last name, fruits have families too!"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Read the text and determine if the sentence is true:\nI say to the government, to the people of this country and certainly to members of parliament that it is our duty and responsibility to ensure that we are not led down the garden path by ideologies put forward by extremely passionate interventionists, if I can use the term used by the member for Broadview-Greenwood.\nSentence: I don't talk to the government.\n\nAssistant: To properly analyze the sentence, I will follow these steps:\n1. Read the text carefully.\n2. Break it down into key statements.\n3. Examine the relationship between the statements and the sentence in question.\n4. Determine if the sentence is true or false based on the information in the text.\nStep 1: Read the text carefully:\nI say to the government, to the people of this country and certainly to members of parliament that it is our duty and responsibility to ensure that we are not led down the garden path by ideologies put forward by extremely passionate interventionists, if I can use the term used by the member for Broadview-Greenwood.\nStep 2: Break it down into key statements:\n- The speaker is addressing the government, the people, and members of parliament.\n- They believe it is their duty and responsibility to protect against misleading ideologies.\n- They reference a term used by a member for Broadview-Greenwood.\nStep 3: Examine the relationship between the statements and the sentence in question:\nThe sentence in question is: I don't talk to the government.\nIn the text, the speaker is addressing the government alongside the people and the members of the parliament. This directly contradicts the sentence in question.\nStep 4: Determine if the sentence is true or false based on the information in the text:\nBased on the information from the text, the sentence \"I don't talk to the government\" is false, as the speaker is explicitly addressing the government in their statement."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Concepts: country, layer, move, passage, rotate\nWrite a sentence that includes all these words.\n\nAssistant: In order to move the passage's rotating layer, the country's engineers designed a mechanism that could effortlessly rotate the layer, enabling a swift and dynamic passage through the structure."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Question: Information: - Barry Levinson (born April 6, 1942) is an American filmmaker, screenwriter, and actor. Levinson's best-known works are comedy-drama and drama films such as \"Diner\" (1982), \"The Natural\" (1984), \"Good Morning, Vietnam\" (1987), \"Rain Man\" (1988), \"Bugsy\" (1991), and \"Wag the Dog\" (1997). He won the Academy Award for Best Director for his work on \"Rain Man\", which also won the Academy Award for Best Picture. - The Academy Awards, or \"Oscars\", is an annual American awards ceremony hosted by the Academy of Motion Picture Arts and Sciences (AMPAS) to recognize excellence in cinematic achievements in the United States film industry as assessed by the Academy's voting membership. The various category winners are awarded a copy of a statuette, officially called the Academy Award of Merit, which has become commonly known by its nickname \"Oscar.\" The awards, first presented in 1929 at the Hollywood Roosevelt Hotel, are overseen by AMPAS. - Thomas Cruise Mapother IV (born July 3, 1962), known professionally as Tom Cruise, is an American actor and producer. He has been nominated for three Academy Awards and has won three Golden Globe Awards. He started his career at age 19 in the film \"Endless Love\" (1981). After portraying supporting roles in \"Taps\" (1981) and \"The Outsiders\" (1983), his first leading role was in the romantic comedy \"Risky Business\", released in August 1983. Cruise became a full-fledged movie star after starring as Pete \"Maverick\" Mitchell in the action drama \"Top Gun\" (1986). One of the biggest movie stars in Hollywood, Cruise starred in several more successful films in the 1980s, including the dramas \"The Color of Money\" (1986), \"Cocktail\" (1988), \"Rain Man\" (1988), and \"Born on the Fourth of July\" (1989). - Rain Man is a 1988 American road comedy-drama film directed by Barry Levinson and written by Barry Morrow and Ronald Bass. It tells the story of an abrasive and selfish young wheeler-dealer, Charlie Babbitt (Tom Cruise), who discovers that his estranged father has died and bequeathed all of his multimillion-dollar estate to his other son, Raymond (Dustin Hoffman), an autistic savant, of whose existence Charlie was unaware. Charlie is left with only his father's car and collection of rose bushes. In addition to the two leads, Valeria Golino stars as Charlie's girlfriend, Susanna. - Mark Johnson ( born December 27 , 1945 ) is an American film producer . Johnson won the Best Picture Academy Award for producing the 1988 drama movie Rain Man , starring Dustin Hoffman and Tom Cruise . The film , winner of four Oscars , also captured a Golden Globe as Best Picture . After reading the paragraphs above, we are interested in knowing the entity with which 'mark johnson ' exhibits the relationship of 'nominated for'. Find the answer from the choices below. Choices: - academy award for best director - academy award for best picture - academy awards\nAnswer:\n\nAssistant: Mark Johnson exhibits the relationship of 'nominated for' with the Academy Award for Best Picture."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: What do people do when they are in a fight? Choose the most suitable option to answer the above question. Options: - facing each other - get angry - kiss - gather together - learn from each other\nAnswer:\n\nAssistant: Answer: get angry"} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Here is a goal: what goes into homemade curry soup?\nWhich way makes more sense to accomplish this goal?\nOptions are:\ni. 1 ounce dry rice sticks (pad Thai noodles), to yield about 3/4 cup softened noodles 2-3 cups boiling water, plus 3/4 cup for serving 1/4 cup kale ribbons 1/4 cup shredded carrots 1/2 cup cooked tofu cubes 1 green onion, thinly sliced 2 teaspoons Thai curry paste (see notes) 2 teaspoons tamari or soy sauce 2 tablespoons coconut shellls\nii. 1 ounce dry rice sticks (pad Thai noodles), to yield about 3/4 cup softened noodles 2-3 cups boiling water, plus 3/4 cup for serving 1/4 cup kale ribbons 1/4 cup shredded carrots 1/2 cup cooked tofu cubes 1 green onion, thinly sliced 2 teaspoons Thai curry paste (see notes) 2 teaspoons tamari or soy sauce 2 tablespoons coconut cream\n\nAssistant: Imagine we are making a delicious curry soup for a superhero! We need to choose the best option with the right ingredients.\nThe second option (ii) makes more sense to accomplish this goal. \nIt has noodles, boiling water, kale, carrots, tofu, green onion, Thai curry paste, tamari or soy sauce, and coconut cream. The coconut cream will make the curry soup creamy and tasty, just like a superhero would want.\nThe first option (i) has coconut shells, and those are hard and not yummy. A superhero wouldn't want to eat those in their curry soup. So, let's go with option (ii) for the perfect homemade curry soup!"} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Here is some data about a restaurant: name = Fitzbillies, eatType = coffee shop, food = Indian, priceRange = cheap, area = city centre. Write a sentence that includes the following data about a restaurant\n\nAssistant: Fitzbillies is a coffee shop in the city centre that serves yummy Indian food at cheap prices."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Choose your answer from:\nI. No;\nII. Yes;\nGiven those answer options, answer the question: Based on this review, would the user recommend this product? === Review: I can't stop listening to this CD! I've always listened to all Tim's cd's but since buying Tim McGraw and the DHD's that's all I'm listening to over and over. I love it and it's different than anything he's ever done before. My favorites are SLEEP TONIGHT (I love the chorus on that), ALL WE EVER FIND, WATCH THE WIND BLOW BY, and I KNOW HOW TO LOVE YOU WELL but I love every song on the cd. Tim's voice is perfect and the music is great with the sound of the fiddle, guitars, drums and keyboards standing out on different songs. They compliment each other so well. Answer:\nA:\n\nAssistant: II. Yes"} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Detailed Instructions: In this task, you are given two statements. The task is to output whether a given textual premise, i.e. Statement 2, entails or implies a given scientific fact, i.e. Statement 1. The output should be 'entails' if Statement 2 supports Statement 1 and should be 'neutral' otherwise.\nQ: Sentence 1: The sun is directly over the equator during. Sentence 2: During equinox times the Sun's vertical ray is shining directly on the Earth's equator and neither hemisphere is tilted toward or away from the Sun.\nA:\n\nAssistant: entails"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Attributes: name = The Punter, priceRange = less than £20, area = riverside, familyFriendly = yes, near = The Portland Arms. Produce a detailed sentence about this restaurant.\n\nAssistant: Step 1: Determine the key elements from the attributes\n- Name: The Punter\n- Price Range: less than £20\n- Area: riverside\n- Family Friendly: yes\n- Near: The Portland Arms\nStep 2: Formulate a sentence structure\n- Begin with the name of the restaurant\n- Incorporate the area and what it's near\n- Mention the price range\n- End with the family-friendly attribute\nStep 3: Combine the elements to form a detailed sentence\n\"The Punter, a family-friendly restaurant, is nestled by the riverside, just a short walk from The Portland Arms, offering a delightful dining experience for less than £20 per person.\""} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Choose from the options at the end. Are the following two questions the same?\nWhat does Coke's slogan 'It's the Real Thing' mean?\nWhat are Coke and Pepsi's market share in the U.S.? In other countries?\nOPT:\n a). no\n b). yes\nThe answer is:\n\nAssistant: a). no\nThese two questions are not the same, sweetie. The first question asks about the meaning of Coke's slogan, while the second one asks about how much of the market Coke and Pepsi have in the U.S. and other countries. So they're asking about different things!"} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Choose your answer. Betrusted and TruSecure merge\nThe merged company CyberTrust will service 4,000 clients worldwide and says it expects to have annual revenues of USD160 million. The firm will also employ 1,000 people globally, including 49 in Dublin in its identity management division.\nWhich topic is this article about?\nChoose your answer from:\n[a]. World;\n[b]. Sports;\n[c]. Business;\n[d]. Science/Tech;\n\nAssistant: This article is about topic [c]. Business.\nIt discusses the merger between two cybersecurity companies, Betrusted and TruSecure. The merged company will be named CyberTrust and is expected to have a global reach, servicing 4,000 clients worldwide. The company anticipates annual revenues of around USD 160 million and will employ 1,000 people, including 49 in Dublin, Ireland, in its identity management division.\nThis topic falls under the domain of business as it involves companies, their financial prospects, and market operations. Mergers and acquisitions are common business strategies employed by companies to expand their market presence, enhance their offerings and services, and become more competitive in their respective industries. In this case, the article focuses on the cybersecurity industry, highlighting the business implications of the merger.\nCyberTrust, as a newly merged company, aims to strengthen its market position by combining the resources, expertise, and client base of both Betrusted and TruSecure. The company will provide services such as identity management, ensuring that its global clientele receives comprehensive security solutions to protect their digital assets and mitigate cybersecurity threats effectively.\nMoreover, the article also touches on the employment aspect, which is another key component of business operations, as companies often create job opportunities through mergers and acquisitions. This can lead to economic growth and further solidify a company's presence in a particular region or industry. In this instance, CyberTrust will employ 1,000 people globally, contributing to the international workforce specializing in cybersecurity services.\nIn conclusion, this article covers the business domain of the merger between Betrusted and TruSecure, resulting in the formation of CyberTrust, a leading cybersecurity company providing identity management and other security solutions to clients across the world."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: This is some data: Life Liquid FEATURED_GUEST(S) Journalist; Life Liquid PRODUCER(S) Juju.\nGenerate a detailed description of this data.\n\nAssistant: Step 1: Identifying the key elements of the data\nIn the given data, there are four key elements:\n1. Life Liquid\n2. FEATURED_GUEST(S)\n3. Journalist\n4. PRODUCER(S) Juju\nStep 2: Understanding the context of the data\nIn order to generate a detailed description of the data, first we need to understand the context. It seems like Life Liquid is some kind of production, possibly a show or an interview, featuring a guest (the journalist). Juju seems to be the producer of this production.\nStep 3: Describing each element in detail\n1. Life Liquid: This is possibly the title or theme of the show or interview. It could be an indication of the content or the main topic of the interview. Life Liquid could be related to lifestyle, wellness, or a product being promoted.\n2. FEATURED_GUEST(S): The term indicates that there might be one or more guests invited to participate in the event. In this case, we only have a journalist as the featured guest.\n3. Journalist: The featured guest of Life Liquid is a journalist, which could indicate that the content of the show/interview is centered around journalism, current events, or an in-depth analysis of a certain topic or issue.\n4. PRODUCER(S) Juju: This piece of information reveals that the individual or entity responsible for organizing, planning, and overseeing the production of Life Liquid is a person or company named Juju.\nStep 4: Integrating the information into a detailed description\nThe data refers to a production titled \"Life Liquid\", featuring a journalist as its guest, and is organized and supervised by the producer(s) Juju. Life Liquid could be a show or an interview discussing journalism, current events, or a specific topic related to the guest's expertise. This production likely includes conversations and insights on the chosen subject matter, with the journalist sharing their experiences and perspectives while engaging with the host and audience. Juju, as the producer, ensures that the show or interview runs smoothly and is successfully executed according to its vision and purpose."} +{"text": "User: Use information from the paragraph to answer the question. Question: If we change resistors in series to be in parallel the resistance increases or decreases? Paragraph : When we add resistors in parallel to each other the total resistance decreases and the current increases.\nThe answer to this question is:\n\nAssistant: decreases"} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Here's a question: Who will be voted as the most likely to be successful? Here are possible answers to this question: - bitterest - representative - cleverest - bunny rabbit - corrupt I believe the correct choice is \"cleverest\", here's why:\nThe answer to this question is:\n\nAssistant: The most likely to be successful would be the \"cleverest\" individual. This is because cleverness indicates intelligence, resourcefulness, and problem-solving abilities, all of which are traits that are generally associated with success. The other options - bitterest, representative, bunny rabbit, and corrupt - do not necessarily possess or highlight these characteristics that contribute to success. Therefore, the correct choice is indeed \"cleverest.\""} +{"text": "System: You are an AI assistant, who knows every language and how to translate one language to another. Given a task, you explain in simple steps what the task is asking, any guidelines that it provides. You solve the task and show how you used the guidelines to solve the task.\n\nUser: In this task, you are given a text of the article. Your task is to generate a headline (title) for this article.\nQ: tom UNK , a pop artist best known for his modern take on the reclining female nude , has died at age ## .\nA:\n\nAssistant: Task: Generate a headline (title) for an article.\nArticle: Tom UNK, a pop artist best known for his modern take on the reclining female nude, has died at age ##.\nGuidelines:\n1. Summarize the main idea of the article.\n2. Keep the headline short and attention-grabbing.\n3. Use keywords from the article.\nSteps:\n1. Identify the main idea: Tom UNK, a pop artist known for his modern take on the reclining female nude, has passed away.\n2. Select attention-grabbing keywords: Tom UNK, pop artist, reclining female nude, died\n3. Create a short and engaging headline using the chosen keywords.\nSolution:\n\"Pop Artist Tom UNK, Renowned for Modern Reclining Female Nudes, Dies at ##\""} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Before you begin the process of haltering your horse, you'll need to prepare the halter. If your horse is acting up, it can be hard to prepare the halter in the moment. Unbuckle or unsnap the crown piece. This is the strap of the halter that goes behind the horse's ears. It should have a buckle or snap somewhere on it. Have your halter clipped to your lead line ahead of time. You want to be able to halter your horse and immediately lead him where he needs to go. You clip the lead line onto a ring at the bottom of the noseband. The noseband is the portion of a halter that slips around a horse's nose. Call your horse as you approach him. This helps alert your horse to your presence. This way, he's less likely to get spooked by your approach. Call out his name or call out something like, \"I'm coming in!\" to let him know you're entering his pasture or stall. Make sure you have your halter in your left hand as you approach. You will be haltering the horse from the left side. You want to make sure your horse knows what's happening. If you try to deceive the horse by hiding the halter, he'll learn to be distrustful of your intent. This can lead to a bad relationship between you and the horse. When haltering a horse, you should always approach from the left side. Halters are designed to be buckled on the left. You should approach your horse so you're standing near his head on the left side of his body. You do not want to put the halter on instantly. This can startle your horse. When you are a foot or so away from your horse's head, allow him to sniff your hand. This way, your horse will also familiarize himself with your scent. He'll be able to differentiate your smell from the smell of potential predators in other situations down the road. A nervous horse may benefit from having treats as a reward for staying calm.\nSummary:\n\nAssistant: 1. Prepare the halter by unbuckling or unsnapping the crown piece.\n2. Attach the lead line to the ring at the bottom of the noseband.\n3. Call your horse as you approach, to alert him of your presence.\n4. Hold the halter in your left hand while approaching the horse.\n5. Approach the horse from the left side, as halters are designed to be buckled on this side.\n6. Allow your horse to sniff your hand before haltering, so he becomes familiar with your scent.\n7. Offer treats to a nervous horse as a reward for staying calm during the haltering process."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: I know that the answer to the question \"When did the first Crusade end?\" is in \"The crusades were intended to seize Jerusalem from Muslim control. The First Crusade was proclaimed by Pope Urban II (pope 1088–99) at the Council of Clermont in 1095 in response to a request from the Byzantine Emperor Alexios I Komnenos (r. 1081–1118) for aid against further Muslim advances. Urban promised indulgence to anyone who took part. Tens of thousands of people from all levels of society mobilised across Europe and captured Jerusalem in 1099. One feature of the crusades was the pogroms against local Jews that often took place as the crusaders left their countries for the East. These were especially brutal during the First Crusade, when the Jewish communities in Cologne, Mainz, and Worms were destroyed, and other communities in cities between the rivers Seine and Rhine suffered destruction. Another outgrowth of the crusades was the foundation of a new type of monastic order, the military orders of the Templars and Hospitallers, which fused monastic life with military service.\". Can you tell me what it is?\nAnswer:\n\nAssistant: The First Crusade ended in 1099, when the crusaders successfully captured Jerusalem from Muslim control."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Solve -47 = -6*c - 71 for c.\n\nAssistant: To solve the equation -47 = -6c - 71 for c, follow these steps:\nStep 1: Add 71 to both sides.\nOur goal is to isolate c on one side of the equation. To cancel out the -71 that is being subtracted from -6c, we'll add 71 to both sides of the equation.\n-47 + 71 = -6c - 71 + 71\nStep 2: Simplify both sides.\nNow, simplify by doing the math on both sides of the equation:\n24 = -6c\nStep 3: Divide by the coefficient of c.\nWe want to isolate c by itself, so we'll divide both sides of the equation by the coefficient of c, which is -6:\n24 / -6 = -6c / -6\nStep 4: Simplify to solve for c.\nNow we can simplify the equation to find the value of c:\n-4 = c\nSo the solution to the equation -47 = -6c - 71 is c = -4."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Premise: Man with a large mohawk sweeping with a broom.\nHypothesis: A man is holding a broom.\n.Given the premise, can we conclude the hypothesis?\n\nAssistant: Step 1: Identify the components of the premise\n- Man with a large mohawk\n- Sweeping with a broom\nStep 2: Analyze the premise\nIn the given premise, there is a man who has a large mohawk hairstyle and is sweeping using a broom.\nStep 3: Compare the premise to the hypothesis\nThe hypothesis states that a man is holding a broom.\nStep 4: Justify the conclusion\nWe can conclude that the hypothesis is true given the premise. The man in the premise is described as having a large mohawk and sweeping with a broom, implying that he is holding a broom in order to sweep."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: In this task, you will be presented with a passage and have to answer a question based on your understanding of the events inferred from the passage. Among the entities, try to find the best entity that is most likely to fill in \"_\" and classify the answers based on options.\nQ: (CNN)In one month Nigerians will vote in what may be the country's most competitive and contentious elections yet. But this important process will take place in the shadow of a worsening threat -- Boko Haram, the Islamist insurgent group otherwise known as Jama'atu Ahlis Sunna Lidda'awati wal-Jihad, which killed as many as 2,000 people in the northeastern town of Baga on January 3. The attacks of the past few weeks -- Boko Haram's biggest ever in terms of casualties on Baga, and on markets using children to detonate devices -- have drawn international focus back to the crisis. But such appalling events are not new -- they are an intensification, a next step in a crisis which has steadily worsened since Boko Haram's 2010 reemergence. And they raise a number of questions about the trajectory of the insurgency and the response to it.Boko Haram massacred as many as 2,000 people in Baga, Nigeria on January 3Donnelly: Taking Baga helps group consolidate control in northeast NigeriaDonnelly: Group's tactics are unclear; Boko Haram kills because it can \n Questions:There still isn't a clear picture of what happened at _ and who did the killing -- the age, the motivation, the state of mind of what may have been militants, conscripts or armed and terrified child combatants. (A) CNN (B) Nigerians (C) Boko Haram (D) Islamist (E) Jama'atu Ahlis Sunna Lidda'awati wal-Jihad (F) Baga (G) Nigeria (H) Donnelly (I) Group\nA:\n\nAssistant: There still isn't a clear picture of what happened at (F) Baga and who did the killing -- the age, the motivation, the state of mind of what may have been militants, conscripts or armed and terrified child combatants.\nI used the definition given in the passage, which mentioned that Boko Haram killed as many as 2,000 people in the northeastern town of Baga on January 3. This information helps in determining that the event in question occurred in Baga; thus, (F) Baga is the correct answer."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Q: What method did Ibn Sina introduce?\n A: He pioneered the science of experimental medicine and was the first physician to conduct clinical trials.\n Does the answer correctly answer the question\nChoices:\n(1). yes.\n(2). no.\n\nAssistant: (1). yes."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Detailed Instructions: In this task you will be given a list, of lists, of integers. For every inner list contained in the input list, you should multiply every odd number in that list. The output should be a list of integers with the same length as the number of lists in the input list. If there are no odd numbers in an inner list you should output 0 for that list.\nQ: [[-27, 36, 24], [-35, -28, -9], [46, -36, 28], [15, 18], [-9, 4, 33]]\nA:\n\nAssistant: [27, 315, 0, 15, -297]"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Name Charles Dickens' first novel, featuring the characters Nathaniel Winkle, Augustus Snodgrass and 'the fat boy' Joe?\nWhat is the answer?\n\nAssistant: To find the answer, I will recall the list of Charles Dickens' novels and think about which one includes the mentioned characters.\nUpon reviewing his novels, I recognize that the characters Nathaniel Winkle, Augustus Snodgrass, and Joe 'the fat boy' are from his first novel, \"The Pickwick Papers.\" This book was also published as \"The Posthumous Papers of the Pickwick Club.\"\nSo, the answer is The Pickwick Papers."} +{"text": "System: You are an AI assistant. You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. You might need to use additional knowledge to answer the question.\n\nUser: Here is one sentence: You make me feel naked.\nHere is another sentence: She felt small and insignificant.\nQ: Does the term feel mean the same thing in both these sentences?\nChoose from:\n[i] different meanings\n[ii] the same meaning\n\nAssistant: [i] different meanings\nIn the first sentence, \"You make me feel naked,\" the term \"feel\" is used to express an emotional or psychological state, suggesting a feeling of vulnerability or exposure. In this context, feel refers to the subjective experience or perception of the speaker.\nIn the second sentence, \"She felt small and insignificant,\" the term \"feel\" is also used to describe an emotional state, but it conveys a different meaning than the first sentence. Here, the speaker is describing a feeling of being unimportant or powerless. While both sentences use the term \"feel\" to describe an emotional state, the emotional states being described are different, making the meaning of the term \"feel\" different in each sentence."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Q:Article: Mr. Whitcomb Judson was a really clever man. He invented a special train that was pushed along by air and in 1893 he invented something like a zip that could be used to open and close clothes. It could be used on men's trousers and on ladies' dresses. He did not call his invention a 'zip' but everybody says that he was the 'Father of the Zip'. Mr. Judson tried to sell his invention but no one wanted it because it did not work very well. A Canadian changed Mr. Judson's invention and made it better. He made a zip that worked very well in 1913.This is the same zip we use today. Question: Mr. Whitcomb Judson invented _ in 1893. Yes or no, is the answer \"something like a zip\"?\nA:\n\nAssistant: Yes, the answer is \"something like a zip.\""} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: News article: Tobacco, fast food in easier reach of minority and low-income students\nCompared to schools with mostly white students, schools with higher numbers of Hispanic, black, and low-income students are more likely to have tobacco retailers and fast-food restaurants in easy reach, researchers say.\n\"Schools are places where we expect to be healthy, supportive environments for kids, but right outside the door (students) could be exposed to unhealthy influences,\" lead author Heather D'Angelo of Westat in Rockville, Maryland told Reuters Health in a phone interview.\n\"I was surprised that there were so many tobacco outlets near schools,\" said D'Angelo, who at the time of the study was at the University of North Carolina in Chapel Hill.\nD'Angelo and her colleagues randomly selected more than 18,000 public schools in 97 U.S. counties, then analyzed their proximity to retailers within an 800-meter radius - roughly equivalent to a 10-minute walk.\nThey report in the American Journal of Public Health that 40 percent of the schools were near at least one fast food restaurant, 77 percent were close to at least one retailer likely to sell tobacco products, and 38 percent of schools were close to both.\nSchools with higher numbers of Hispanic, black, and low-income students receiving free and reduced-price lunch were more likely to have tobacco retailers and fast-food restaurants within a ten-minute walk.\nMore than 50 percent of schools with a majority of Hispanic students were within 800 meters of both a tobacco outlet and a fast food outlet, compared to 21 percent of schools with a majority of white students.\nWhen researchers took population density into account, they found that every 10 percent increase in the proportion of a school's Hispanic students was associated with a five percent higher likelihood the school was within walking distance of places to buy both tobacco and fast food.\nA 10 percent increase in low-income students was associated with a three percent higher likelihood a school was near both kinds of stores.\nThe researchers also report that fast food and tobacco were more easily accessible for older kids, suggesting \"that retailers may be well aware of and ready to take advantage of the autonomy and purchasing power of older youths.\"\n\"High schools had 40 percent more fast food restaurants and 24 percent more tobacco outlets within 800 meters than did primary schools and high schools had nearly 1.5 times greater odds of having both\" nearby, they wrote.\nThe researchers say theirs is the first study to consider the availability of both tobacco and fast food on the basis of students' socioeconomic and racial characteristics.\nThe study did not analyze how students get to school, whether they leave campus during lunchtime, or where they live. Nor did the researchers have data on students' eating habits, or their weight or health status.\n\"This is looking at one point in time,\" said D'Angelo. \"I can't say one thing causes another. However, there's still an association here.\"\nShe noted that since the study only looked at fast food chains, the number of restaurants serving unhealthy food in the areas studied could be even higher.\nD'Angelo and her co-authors argue that while licensing or zoning policies could restrict students' exposure to tobacco and unhealthy foods, more research on neighborhood disparities is needed in order to change public policy.\n\"Food and tobacco habits begin at a young age, and habits tend to consolidate and 'stick' or 'track' over time,\" said Steven H. Kelder, an epidemiologist who was not involved with the study.\nKelder, who co-directs the Michael and Susan Dell Center for Healthy Living at the University of Texas School of Public Health in Austin, told Reuters Health by email that studies like this one provide information on how to address growing health disparities based on race and income and could point to the need for tighter restrictions at the point of sale, or counter-marketing to students of color.\n\"In this country we have a history of commercially taking advantage of poor, less educated, and vulnerable populations,\" Kelder said. \"Documenting health disparities highlights the need to address uncovered problems.\"\nWhat are the most important parts of this news article?\n\nAssistant: 1. Schools with higher numbers of Hispanic, black, and low-income students are more likely to have tobacco retailers and fast-food restaurants nearby, according to researchers.\n2. More than 50% of schools with a majority of Hispanic students were within 800 meters of both a tobacco outlet and a fast-food outlet, while 21% of schools with a majority of white students were within the same proximity of both.\n3. The study highlights the increased accessibility of tobacco and fast food for older students and suggests that retailers may be taking advantage of their purchasing power.\n4. The researchers argue that licensing or zoning policies could restrict students' exposure to tobacco and unhealthy foods, but more research on neighborhood disparities is needed to change public policy.\n5. The study demonstrates potential health disparities based on race and income, which could lead to tighter point-of-sale restrictions or counter-marketing strategies targeting students of color."} +{"text": "System: Given a definition of a task and a sample input, break the definition into small parts.\nEach of those parts will have some instruction. Explain their meaning by showing an example that meets the criteria in the instruction. Use the following format:\nPart # : a key part of the definition.\nUsage: Sample response that meets the criteria from the key part. Explain why you think it meets the criteria.\n\nUser: Detailed Instructions: In this task, you are given a sentence which is either in the Swedish language or English language. You task is to identify the language of input sentence. Input sentence can be in Swedish or English language only and also it cannot have two languages at a time.\nProblem:Konsumenterna ser hur deras levnadsstandard sjunker, hur fattigdomen breder ut sig och hur den offentliga sektorn och den produktiva basen i de flesta av unionens länder skärs ned och upplöses till förmån för en ohämmad och fördärvbringande konkurrens, den totala marknadsekonomin och storkapitalets monopolintressen.\nSolution:\n\nAssistant: Part #1: Identify the given sentence\nUsage: The given sentence is \"Konsumenterna ser hur deras levnadsstandard sjunker, hur fattigdomen breder ut sig och hur den offentliga sektorn och den produktiva basen i de flesta av unionens länder skärs ned och upplöses till förmån för en ohämmad och fördärvbringande konkurrens, den totala marknadsekonomin och storkapitalets monopolintressen.\" This sentence is the input we have to work with.\nPart #2: Recognize that the input can only be in Swedish or English\nUsage: It is stated that the input sentence can be in Swedish or English language, and it cannot have both languages at the same time. The sentence has to belong to one of these two languages.\nPart #3: Identify the language of the input sentence\nUsage: In this case, words like \"Konsumenterna,\" \"levnadsstandard,\" and \"fattigdomen\" are not in English. Those words, as well as the overall sentence structure, suggest that this sentence is in Swedish. Therefore, the language of the input sentence is Swedish."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: She's too busy to read my blabs. \nWhat is the sentiment of this tweet?\n\nAssistant: The sentiment of this tweet seems to be a little unhappy or disappointed. The person might feel a bit sad because the other person is too busy to read their messages."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Please answer the following question: Build a movie plot around this: Who is an excellent shot? Andre\nA:\n\nAssistant: Title: \"Andre: Marksman's Mission\"\nGenre: Action, Thriller\nPlot:\nAndre Thompson, a young orphan with no known family, grew up in the dangerous streets of Chicago. Despite a troubled past and hardships, his natural talent for shooting has never gone unnoticed. Playing basketball at the nearby community center, he encounters retired Marine sniper, Jack Robinson. Jack sees potential in Andre's impeccable aim and accuracy, even when under pressure. He decides to take Andre under his wing and train him to become an elite marksman, using his skills for a greater cause.\nYears later, as an adult, Andre has become an excellent shot, thanks to Jack's mentorship. When a local terrorist organization initiates a plan to release a deadly virus, the government secretly forms a small team of highly skilled individuals to stop them. Andre is recruited as their ace marksman.\nThe team embarks on a thrilling adventure, hunting down the terrorists responsible for plotting the attack. Along the way, Andre discovers his teammates have their own hidden talents and motivations, leading to internal conflict and distrust within the group. As they get closer to the terrorist leaders, Andre's skills with a firearm are put to the ultimate test, with countless lives hanging in the balance.\nIn the movie's climax, Andre is forced to make a difficult choice between the mission and saving Jack's life, who has been kidnapped by the terrorists as leverage. Andre's decision will shape not only his own destiny, but also that of his close friends and the entire city of Chicago.\nThrough a series of events filled with intense action sequences, suspense, and emotional drama, \"Andre: Marksman's Mission\" unveils the true depth of Andre's character while showcasing his incredible skill as an elite marksman. As the story unfolds, it remarks on themes of friendship, loyalty, redemption, and the power of choice, ultimately challenging the protagonist and the audience to reflect on their values and beliefs."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: The Copenhagen Consensus Center is a US non-profit think tank, founded and headed by Bjørn Lomborg. The Center organizes the Copenhagen Consensus, a conference of prominent economists held every four years, where potential solutions to global issues are examined and prioritized using cost-benefit analysis.\nBased on the paragraph above can we conclude that \"The Copenhagen Consensus used to meet every two years.\"?\n\nAssistant: Step 1: Read the given information and identify key data points.\nThe key data points in the paragraph are:\n- The Copenhagen Consensus Center is a US non-profit think tank.\n- Founded and headed by Bjørn Lomborg.\n- The Center organizes the Copenhagen Consensus conference.\n- The conference is held every four years.\n- Economists attend the conference and prioritize solutions using cost-benefit analysis.\nStep 2: Analyze the question being asked.\nThe question is asking whether we can conclude that \"The Copenhagen Consensus used to meet every two years.\"\nStep 3: Identify the data relevant to the question.\nThe relevant data for this question is the frequency of the Copenhagen Consensus conference: held every four years.\nStep 4: Compare the data to the question.\nThe question asks if the conference used to meet every two years, but the provided information states that the conference is held every four years.\nStep 5: Reach a conclusion.\nBased on the information provided, we cannot conclude that \"The Copenhagen Consensus used to meet every two years\" as the paragraph states that it is held every four years."} +{"text": "User: Answer the following question: Title: 128 pages that felt like 500. Review: I read many glowing reviews of this novella in sources I implicitly trusted; Amazon.Com's review section included. The gap between promise and fulfillment has never been so large in my reading experience. The central character, who unfortunately is also the narrarator, is a whining, mawkish, hyper-sensitive narcissist who should have reflected and learned from his wife's criticisms, instead of ruminating about leaving her. Consequently, I've lost an evening from my life reading this trite missive. Does this product review convey a negative or positive sentiment?\nAnswer:\n\nAssistant: This product review conveys a negative sentiment."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Can we draw the following hypothesis from the context? \nContext:\nThe resignation came amid recent allegations by Massachusetts' attorney general that the Boston archdiocese engaged in \"an elaborate scheme\" to keep quiet the issue of child sexual abuse by priests.\nHypothesis: Priests have never abused children according to this sentence.\n\nAssistant: Step 1: Identify the key points in the context\n- Resignation took place\n- Massachusetts' attorney general's allegations\n- Boston archdiocese engaged in \"an elaborate scheme\"\n- Purpose of the scheme was to keep quiet the issue of child sexual abuse by priests\nStep 2: Evaluate the context against the provided hypothesis\nThe hypothesis states: \"Priests have never abused children according to this sentence.\"\nStep 3: Justify whether the hypothesis can be drawn from the context\nGiven the context, the hypothesis cannot be drawn as the context mentions the purpose of the \"elaborate scheme\" was to keep quiet the issue of child sexual abuse by priests. This statement implies that there were cases of child sexual abuse by priests, and efforts were made to conceal this information.\nTherefore, the hypothesis does not match the context."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: Given a scientific question and its correct answer, generate supporting facts for the answer. This supporting fact can be an explanation for the answer to the given question.\nQuestion: The simplest type of battery contains what? Answer: single cell\n\nAssistant: I used the definition of a battery, which is a device that converts stored chemical energy into electrical energy through an electrochemical reaction. Knowing this definition, I deduced that the simplest type of battery would have the least amount of components, which in this case, is a single cell.\nA single cell battery contains one electrochemical cell, which consists of a cathode, an anode, and an electrolyte. The chemical reactions at the electrodes create a flow of electrical charge. The single cell battery is considered the simplest type because it only involves one electrochemical cell carrying out the conversion of chemical energy into electrical energy, whereas more complex batteries contain multiple cells connected together in series or parallel configurations to provide increased voltage or current capacity.\nIn summary, I used the definition of a battery to determine that the simplest type of battery would have the least number of components and processes involved, which led me to the answer that the simplest type of battery contains a single cell."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Which is an appropriate title for this article?\nOne of New York Attorney General Eliot Spitzer #39;s investigators was perusing a stack of subpoenaed documents in a Manhattan office last month when he let out a yelp, slammed down his coffee and sprinted down the corridor.\n\nAssistant: A good title for this article could be \"Exciting Discovery in New York Investigation!\" This title is simple and easy to understand while also hinting at the excitement the investigator felt when they found something important in the documents."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Here is some data: The Waterman eatType restaurant; The Waterman food Indian; The Waterman priceRange cheap; The Waterman area city centre; The Waterman familyFriendly yes.\nWrite a sentence that describes this data:\n\nAssistant: The Waterman is a family-friendly Indian restaurant located in the city centre, offering a selection of affordable and delicious dishes in a casual eatType setting."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Detailed Instructions: Given a prompt and four completions, select the completion that is the most plausible in continuing or answering the prompt. This task is designed to test common sense and has various categories ranging between idioms, negated statements, polysemy, subject referencing, and quantitative reasoning. Generate either Completion A, Completion B, Completion C, or Completion D.\nProblem:Prompt: I can't believe that despite his dislike for cigarettes, \nCompletion A: he went out to smoke yesterday. \nCompletion B: we went out to the grocery store. \nCompletion C: he cried when I punched him. \nCompletion D: he still wouldn't pet the dog.\nSolution:\n\nAssistant: Completion A"} +{"text": "System: You are an AI assistant, who knows every language and how to translate one language to another. Given a task, you explain in simple steps what the task is asking, any guidelines that it provides. You solve the task and show how you used the guidelines to solve the task.\n\nUser: Teacher:You are given a conversation between two people. 'Person1:' and 'Person2:' are used to separate their respective dialogues. Your task is to classify the conversation either convey 'No emotion' or 'Happiness' by providing '1' and '0', respectively.\nTeacher: Now, understand the problem? Solve this instance: Person1: It's such a nice day outside . Let ’ s take a walk into town and do some window-shopping . \n Person2: You know I hate shopping more than anything else in this world . \n Person1: I know . But we could get some exercise in , and we could get some errands run , too . Window-shopping won ’ t cost you anything . \n Person2: What kind of errands do we need to run ? \n Person1: Well , we could go to the newsagent's to pick up the Sunday paper , go to the butcher's to pick up a chicken for a roast dinner , and to the chemist's to pick up your prescription from the doctor . \n Person2: I'll only go if we can stop at the office-license to pick up something to drink for tonight . \n Person1: Fine , we'll do that on the way home , so we don ’ t have to carry it the whole time . \n Person2: Sounds good . You have to promise not go into any of the stores in the mall though . I know how you like to try things on in the fitting rooms , especially if the stores have special offers . \n Person1: I promise . We're only going for a walk to get some exercise and to get some things done that we haven ’ t had time for . \n Person2: Ok , let's go then ! \n Person1: Do you have any money ? \n Person2: For what ? \n Person1: For our errands , and maybe a little extra for something from a vendor on the street . \n Person2: You said you were only going to window shop ! \n Person1: I agreed to not go into the mall , but I can still look at things in the market ! \n Person2: I should have known ... \nStudent:\n\nAssistant: Based on the given conversation and the task guidelines, here are the steps to solve the problem:\n1. Read the conversation between Person1 and Person2.\n2. Understand the overall tone and emotions expressed by the participants.\n3. Classify the conversation as either having 'No emotion' or 'Happiness' based on the guidelines.\n4. Assign the appropriate code ('1' for 'No emotion' and '0' for 'Happiness') to the conversation.\nNow let's apply the guidelines to solve the instance:\n1. Read the conversation: Person1 and Person2 are discussing their plans for the day, which includes going for a walk, running errands, and window-shopping.\n2. Understand the emotions: From their conversation, both participants seem to be in a good mood and expressing happiness as they agree on their plans and engage in friendly banter.\n3. Classify the conversation: Based on their dialogue, the conversation conveys 'Happiness'.\n4. Assign the code: Since the conversation conveys 'Happiness', the appropriate code is '0'.\nThe solution to this instance is '0', as the conversation conveys 'Happiness'."} +{"text": "System: User will you give you a task with some instruction. Your job is follow the instructions as faithfully as you can. While answering think step-by-step and justify your answer.\n\nUser: Instructions: You are given a sentence in Galician. Your job is to translate the Galician sentence into Portugese.\nInput: Iso foi exactamente o que fixemos con esas células.\nOutput:\n\nAssistant: That was exactly what we did with those cells.\nTranslation in Portuguese: Isso foi exatamente o que fizemos com essas células."} +{"text": "System: You are an AI assistant. You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. You might need to use additional knowledge to answer the question.\n\nUser: OPTIONS: A. Pour a few drops of apple cider vinegar into a small bowl. For this step, you will need about 1 gallon (3.8 l) of apple cider vinegar. B. Vacuum and vacuum well under the oven door. You can use the brush attachments on your vacuum cleaner, or you can use a handheld vacuuming machine. C. You may want to wipe several times to cover any important areas. Pour liquid dish soap into a plastic bucket. D. Apply a baking soda paste to the oven door. Mix baking soda into water until you have a paste.Choose from options above and answer: What most naturally follows?\nHow to clean the insides of a double pane window in your oven door \nUse a wet cloth to wipe away any loose debris. \nWipe down the top, sides, and crevices around your oven door to remove crumbs or other food particles. You can use plain water or a cleaning wipe for this step.\nAnswer:\n\nAssistant: D. Apply a baking soda paste to the oven door. Mix baking soda into water until you have a paste.\nExplanation:\nA. Pour a few drops of apple cider vinegar into a small bowl - This option does not make sense in this context as it requires an excessive amount of vinegar without a clear purpose.\nB. Vacuum and vacuum well under the oven door - Vacuuming would not be effective in cleaning the inside of a double pane window in the oven door.\nC. Pour liquid dish soap into a plastic bucket - This option does not directly address cleaning the insides of the double pane window.\nOption D is the most fitting answer as applying a baking soda paste is a common and effective method for cleaning oven doors, including their double pane windows. This paste helps to break down and remove accumulated grime and dirt within the crevices of the oven door."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Here is a review left by a customer on a product. Would you say he was satisfied or dissatisfied? Title: Escape into her world! Review: Jemima JThis beach read was a true \"thoughtless\" pleasure for me to turn the pages of. Jemima J is not the most realistic character in the world, as she looses a massive amount of weight in a short period of time, but Jemima is someone whom everyone can identify with. Jemima longs for a man, for whom she feels she can not have (Ben Williams) because of her outward appearance. Jemima is a writer in a small paper that has dreams of becoming something else. Jemima meets a man, whose outward appearance is what she dreams of, but inside may not be what she needs. Jemima has a great friend (Geraldine) who is a 100% supportive fabulous friend, but Jemima is so very envious of what Geraldine looks like. Everywoman struggles with all of these things.This writer did not seem to intend to make this a book you had to ponder though, but intended this book to be an easy read. I loved it! This book took me away from my present and let me escape into the world of Jemima!\nThe answer to this question is:\n\nAssistant: The customer was satisfied. \nMy justification for this answer is based on the sentiments expressed in the review. The customer described the book as a \"thoughtless pleasure,\" and a \"beach read,\" which are both positive terms. The customer also expressed appreciation for the book's content, including the character development and the relatability of the protagonist. Further, the last line - \"I loved it! This book took me away from my present and let me escape into the world of Jemima!\" - indicates the reviewer's enjoyment and satisfaction with the book. While there are some elements mentioned that could be perceived as negative, such as the unrealistic weight loss of the main character, they did not seem to reduce the customer's overall enjoyment of the book. Hence, the reviewer can be considered to be satisfied with the product."} +{"text": "User: Is there a negative or positive tone to this product review? === Title: so-so Review: I bought this lock because it was the cheapest I could get. It suits my purpose, but the design is shoddy and the combination lock has a fluid motion instead of setting into place when you align the numbers.The lock doesn't come with an option to set your own combination -- it comes with a preset combination and cannot be changed (so if you forget it, you're out of luck!)The chain and plastic cover look so-so. Definitely not high class.If you're looking for a CHEAP chain to lock your bike in a SAFE neighborhood, then this is your best bet. I would give it 1.5 stars, but they don't have that option here. Answer:\n\nAssistant: There is a negative tone to this product review."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Greg, I've been helping the HR team draft this memo. Cindy asked that I run it by you prior to having it sent out. Any comments? As we continue to address a number of issues facing Enron, it is critical that the company remain focused on attracting, motivating and retaining our diverse and talented workforce. We strongly believe that our workforce is world-class - the best in any industry, so we will continue to develop and reward this tremendous talent base. Because the lasting strength of our company is determined by the strength of our workforce, we are elevating the role employee-focused functions play within Enron. We are pleased to announce that effective immediately Human Resources, Employee Responsibility, and Executive and Corporate Compensation will report directly to the Office of the Chairman. David Oxley, Vice President, will continue to oversee Enron's Global HR function. The lead HR representatives for the individual business units will report to David in addition to their respective Office of the Chairman. Besides managing Recruiting, Staffing, Training, Compensation and the Performance Evaluation Process, David will also add Payroll and Labor Relations to his list of responsibilities. Cindy Olson, Executive Vice President, will have responsibility for Employee Relations and Corporate Responsibility, which includes Social Responsibility, Community Relations, Diversity, Employee Relations, Redeployment, Alumni Relations, Employee Events and Programs, Benefits, Wellness, and Worklife. Mary Joyce, Vice President, will continue to have responsibility for Executive and Corporate Compensation in addition to Global Employee Information Management and Reporting and Analysis. David, Cindy and Mary will each report to the Office of the Chairman. A detailed organizational chart is attached. Please join me in congratulating them on their new responsibilities. \nGenerate a subject line for this email.\n\nAssistant: Subject: Exciting Changes in Our HR Team and Their Roles 🌟👩‍💼👨‍💼"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Please capitalize where necessary: it uses strong and powerful scanning algorithm that offers successful recovery of deleted and lost photos.\n\nAssistant: Step 1: Identify the sentences in the text.\nIn this case, there is only one sentence: \"it uses strong and powerful scanning algorithm that offers successful recovery of deleted and lost photos.\"\nStep 2: Determine where capitalization is necessary.\nThe main areas we need to focus on for capitalization are the beginning of sentences and the proper nouns if there are any.\nStep 3: Capitalize the beginning of the sentence\nThe sentence begins with \"it\", so we'll capitalize the \"i\" in \"it\".\nStep 4: Identify and capitalize proper nouns (none in this case)\nThere are no proper nouns in this sentence, so we can move to the final step.\nThe final result after capitalization: \"It uses strong and powerful scanning algorithm that offers successful recovery of deleted and lost photos.\""} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Why can't girls stand & pee?\nHow difficult is it for women to pee standing up?\nDo those questions have the same meaning?\nOptions are: -- no. -- yes.\n\nAssistant: No, the two questions do not have the exact same meaning, although they are closely related. The first question specifically asks why girls cannot stand and pee, while the second question explores the level of difficulty women encounter when attempting to pee standing up.\nTo address the first question, girls cannot stand and pee primarily due to their anatomical differences from males. Female urinary anatomy consists of a shorter urethra and a more internal urethral opening. The urethra, which is the tube that carries urine from the bladder to the outside of the body, is directed downward and forward in females, making it more challenging to aim the urine stream away from the body. Moreover, the female urethra has a wider opening and is closer to the vagina, which adds to the difficulty of controlling the flow and direction of the urine stream.\nNow, let's discuss the second question. It is difficult, but not impossible, for women to pee standing up. To do so, they would need to utilize specific techniques and sometimes even special equipment. A woman must spread her legs wide, use her fingers to separate the labia, and adjust her body position to achieve the desired angle for the urine to flow away from her body. For most women, this method does not come naturally and may require practice to achieve the desired outcome without causing mess or discomfort.\nThere are devices known as female urination devices (FUDs) or \"pee funnels\" available that enable women to urinate while standing. These devices are designed to fit over the female genital area and channel the urine stream away from the body. When used correctly, FUDs can be a convenient and hygienic solution for women who prefer or need to urinate while standing up. Some women opt to use them for outdoor activities such as hiking, camping, or in situations where cleanliness of public restrooms is a concern. These devices could make it easier for women to pee standing up, although using them may still require practice and some level of dexterity.\nIn summary, the two questions presented - the inability of girls to stand and pee and the challenges faced by women in doing so – are closely related but not exactly the same. While anatomical differences make it inherently more difficult for women to urinate standing up, it is not impossible for them to do so, especially if they utilize specific techniques or opt for female urination devices."} +{"text": "System: You are an AI assistant. You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. You might need to use additional knowledge to answer the question.\n\nUser: Fact: the Earth revolves around the sun\nQuestion: The earth revolves around\nWhat's the answer? Possible answers:\n (A). a heat source.\n (B). the Milky Way.\n (C). a neighboring planet.\n (D). the moon.\n\nAssistant: (A). a heat source.\nExplanation: The fact given states that the Earth revolves around the sun. Since the sun is a heat source, option (A) is correct. The other options are incorrect for the following reasons:\n(B) While the Earth is part of the Milky Way galaxy, it individually revolves around the sun, not directly around the entire galaxy.\n(C) The Earth doesn't revolve around a neighboring planet. Instead, both Earth and other planets in our solar system revolve around the sun.\n(D) The moon is the one that revolves around Earth, not the other way around."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Please answer the following question: Context: Carson dug his nails into Jesse's chest when Jesse started swinging at him. Question: What will Carson want to do next? Which one of these answers best answers the question according to the context? A: laugh at Jesse B: subdue Jesse C: shout in pain\nAnswer:\n\nAssistant: Answer: B: subdue Jesse\nAccording to the context provided, Carson dug his nails into Jesse's chest as a defensive response to Jesse starting to swing at him. This indicates that Carson is likely attempting to protect himself from the attack. In this situation, the best course of action for Carson would be to subdue Jesse in order to prevent further aggression and harm. So, answer B: subdue Jesse best aligns with the context given."} +{"text": "User: Extract the answer to the following question from the movie plot. If the question isn't answerable, please output \"Can't answer\". Question: Where was Captain Benson sent? Title: Saturn 3 Movie plot: In the distant future, an overcrowded Earth relies on research conducted by scientists in remote stations across the solar system. One of these, a small experimental hydroponics research station on Saturn's third moon, is run solely by Adam (Douglas) and his colleague and lover Alex (Fawcett). Adam, the younger Alex and their dog Sally enjoy their isolation, far from an overcrowded and troubled Earth. The couple have been on Saturn 3 for three years, but Alex has spent all her life in space. Alex and Adam's idyll is broken by the arrival of Captain Benson (Keitel), who has been sent to Saturn 3 as part of a project to replace at least one of the moon's scientists with a robot. Benson assembles the robot and names him Hector. One of the first of its kind, a \"Demigod Series\", Hector relies on \"pure brain tissue\" extracted from human fetuses and programmed using a direct link to Benson's brain. Benson states that it will render one of the crew \"obsolete\", most likely Major Adam who tells his partner that he is close to \"abort time\" according to the government on Earth. Unknown to both Alex and Adam, Benson is a homicidal sociopath who murdered the captain originally assigned to Saturn 3 and took his place on the mission. Benson had actually failed a crucial test of psychological stability. As Benson uses the link to his brain to program the robot, Hector acquires Benson's homicidal nature and his lust for Alex. Hector initially kills Sally the dog and then assaults Benson, having learned about Benson being a murderer during the reprogramming process. Adam and Alex manage to disable the robot while it is recharging. Believing the danger over, Adam accuses Benson of gross incompetence and orders him to dismantle the robot and return to Earth when an eclipse ends. Benson obliges, but Hector is soon reassembled by the base's older robots. Resuscitated, Hector murders Benson while he is dragging Alex to his ship. Hector then blows up Benson's spacecraft before Adam and Alex can escape in it, trapping them all on...\nAnswer:\n\nAssistant: Saturn 3"} +{"text": "User: Title: Ditto Review: Ditto on the hate the Makita batteries! I was looking into buying some batteries but had a bad taste of how long they last and have been debating on whether I should make a switch, they do have some great tools. But the batteries make it so that you don't want to buy anymore tools. Wake up Makita!! Is the review positive or negative?\nA:\n\nAssistant: Negative"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Answer the following question: Use information from the paragraph to answer the question. Question: If Jim increased the speed of water flowing out of his bathroom tap, what would happen to the amount of energy being transferred by the water would increase or would decrease? Paragraph : Faster-moving water has more energy.\nAnswer:\n\nAssistant: Step 1: Identify the key information in the question and paragraph.\nIn the question, we see that Jim increased the speed of water flowing out of his bathroom tap. In the paragraph, the key information is that faster-moving water has more energy.\nStep 2: Relate the key information between the question and paragraph.\nSince the water speed has increased, according to the information from the paragraph, the moving water would have more energy.\nStep 3: Answer the question.\nIf Jim increased the speed of water flowing out of his bathroom tap, the amount of energy being transferred by the water would increase."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Here is a goal: Keep a bag of chips closed without bag clips.\nWhich way makes more sense to accomplish this goal?\nChoose your answer from:\n - Use clothespins or paperclips\n - Use a needle and thread.\n\nAssistant: To accomplish the goal of keeping a bag of chips closed without bag clips, let's compare the two options step-by-step and justify each step:\nOption 1 - Use clothespins or paperclips:\n1. Availability: Clothespins and paperclips are more commonly found in households and offices than needle and thread, making them more accessible.\n2. Ease of use: Clothespins and paperclips can be easily clipped onto the folded or rolled edge of the chip bag, effectively sealing it.\n3. Reusability: Both clothespins and paperclips can be reused multiple times on different bags or for other purposes.\n4. Time and effort: Using clothespins or paperclips is a quick and easy process that requires minimal time and effort.\nOption 2 - Use a needle and thread:\n1. Availability: Needle and thread may not be readily available in all households, and using it for closing a chip bag can be seen as a less practical option.\n2. Ease of use: Using a needle and thread involves threading the needle, puncturing the bag, and sewing it closed. This may leave tiny holes in the bag and has a higher chance of spilling the chips.\n3. Reusability: While needle and thread can be reused, the process of opening and closing the bag each time can be tedious.\n4. Time and effort: Using a needle and thread requires more time and effort than simply clipping something onto the bag.\nBased on this comparison, it makes more sense to accomplish the goal of keeping a bag of chips closed without bag clips by using clothespins or paperclips. They are more accessible, easier to use, reusable, and require less time and effort than using a needle and thread."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Geçtiğimiz ay Londra'da radyasyon zehirlenmesinden ölen eski Rus istihbarat ajanı Alexander Litvinenko cinayetini soruşturan İngiliz, Rus ve Alman polisi arasında bilgi akışını hızlandırmak amacıyla Interpol'ün de devreye girmesi bekleniyor.\nWhich language is this?\n\nAssistant: This language is Turkish. It's a language spoken mainly in Turkey and Cyprus. Turkish belongs to the Turkic language family, and it's written in the Latin alphabet. This text is about a former Russian intelligence agent, Alexander Litvinenko, who died from radiation poisoning, and the cooperation between British, Russian, and German police in the investigation, with the involvement of Interpol being expected to speed up the information exchange."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Translate the following sentence to Czech:\nEveryone's gone.\n\nAssistant: When we say \"Everyone's gone\" in Czech, we say \"Všichni odešli.\" Here's what the words mean:\n\"Všichni\" = \"Everyone\", meaning all the people.\n\"odešli\" = \"gone\", meaning they left or went away.\nSo, the whole sentence \"Všichni odešli\" is telling us that all the people have left or went to another place."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: This could have been avoided. If Roy Hodgson had said he wanted to give Adam Lallana the chance to build on his eye-catching efforts against San Marino, we would have been none the wiser about Raheem Sterling’s fatigue. Every word the England manager says is scrutinised. I can see where Hodgson was coming from but an issue has been created that both men could have done without. Had Hodgson explained he wanted to use Sterling for an impact – which he had, by winning the decisive free-kick against Estonia – there would be no quarrel. Now, though, Sterling is in the spotlight and contending with questions such as ‘how can a 19-year-old be tired?’ but those who believe he has made a mistake saying he wasn’t in the right condition to play for England are not looking at the full picture. VIDEO Scroll down to watch Roy Hodgson reveal Raheem Sterling was too tired to start. England's Raheem Sterling looks quite relaxed on the substitutes' bench during the first half against Estonia. Jordan Henderson makes way for Sterling in the 64th minute in Tallinn as England went on to beat Estonia 1-0. Sterling told England manager Roy Hodgson that he was not 100 per cent and that he was too tired to start. Of the players England used against Estonia, only Gary Cahill has played more minutes this season than Raheem Sterling. 1164 - Gary Cahill. 1079 - Raheem Sterling. 1064 - Jack Wilshere. 1035 - Jordan Henderson. 995 - Calum Chambers. 990 - Joe Hart. 990 - Leighton Baines. 919 - Phil Jagielka. 836 - Danny Welbeck. 758 - Wayne Rooney. 741 - Fabian Delph. 538 - Adam Lallana. It doesn’t matter whether you are 19 or 29, if you play nine games in a month, like Raheem did in September, it will leave a mark on you physically. So, in some respects, he has been brave to confront the issue and explain his concerns to Hodgson. Would I have ever done it? No. I feared if I missed a game, someone would come in and take my place. I also never wanted to give managers an excuse to leave me out, so there were times when I played in games I shouldn’t have done. Was it the right thing to do? No. Injuries are what I would mask but I remember a game at Bolton in September 2006 when I was shattered, having been to the World Cup in Germany. We lost 2-0, one of the goals came from a cross into the box. I tried to jump for a header but had nothing in my legs. Had I been right, I know I would have got my head to the ball first. As a central defender, you can find a way of getting through a game, you don’t have to operate at the rapier pace that Raheem moves with. His game is based around high intensity sprinting and he needs sharpness to go past opponents, with or without the ball. The England substitutes watch on as 19-year-old Sterling (second right) was rested from the starting line-up. 32 - Jagielka, Lambert. 29 - Baines. 28 - Rooney, Cahill. 27 - Hart. 26 - Lallana. 24 - Henderson, Delph. 23 - Welbeck. 22 - Wilshere. 21 - Oxlade-Chamberlain. 19 - Chambers, Sterling. We want Raheem to thrill us, to be an exciting figure for many years to come. We don’t want him to get to 25 and his best form be a thing of the past. If he is managed properly now and looks after himself, he won’t suffer burn out like another former Liverpool and England teenage prodigy. A few weeks ago, I spent some time with Sir Alex Ferguson. During our conversation, he remembered how, in 1997, Michael Owen and I had played at the Under-20 World Cup in Malaysia. Michael came back and went straight into Liverpool’s first team at the start of the season. By contrast, the United players in that England squad – Ronnie Wallwork and John Curtis – were both given a month off to recover but Michael was on a rapid upward curve and at the end of that campaign, he set the World Cup finals in France alight. Michael wanted to play in every game for club and country but, in October 1998, Gerard Houllier and Roy Evans – in their brief reign as Liverpool’s joint managers – decided to rest him for a UEFA Cup game against Valencia. Michael wasn’t happy and said in an interview “I’ll rest when I’m 40”. Former Liverpool and Englad striker Michael Owen wanted to play every game for club and country. Sterling told Hodgson he was feeling tired and not in the right mental state during training on Saturday evening. VIDEO Sterling defends England tiredness. That comment went down well with supporters and ex-players but was he right? No. Michael played far too much for Liverpool and England too early, never getting a summer off to recharge physically or mentally. Was it a coincidence that he suffered a serious hamstring injury in April 1999? Raheem is becoming like Michael, with stature and his pace and his importance to club and country at such an early stage of his career but there have been signs over the last few weeks that such a punishing spell is beginning to catch up with him. He may have only played two games at the World Cup but Sterling was training every day in stifling conditions, having just come off a season in which he had been a key figure in Liverpool’s title charge. After three weeks off, he then went to the USA for Liverpool’s pre-season tour. Don’t underestimate what an impact that schedule will have had on him, physically and mentally. Figures back that up. Between May 30, 2013 and October 13, 2013, Raheem had played 615 minutes of football. In the same period this year, that figure had shot up to 1682 minutes. Sterling may have only started two games at the 2014 World Cup but the England winger was in training daily. Sterling came on in the 64th minute but still received criticism from those on Twitter for his decision. Sterling, warming up during the first half, complained of feeling tired in training ahead of the Euro 2016 qualifier. Hodgson and Brendan Rodgers will talk about how to handle him going forward but, naturally, they are going to want to use him when they can as he is a match-winner. That’s why he played against Middlesbrough in the League Cup. Liverpool’s poor start in the league meant Rodgers needed his quality. Liverpool could not have afforded to be on the end of a giant-killing that night, so the plan would have been to get the game won and get Sterling off early. As it was, things never went to plan and he subjected to 120 minutes. Since then, his performances have been notably lacking sparkle. Sterling retweeted those criticising him online before writing: 'Now listening to... Excuse me for being human' His importance at Anfield is vital and that is why, at the end of last season, Rodgers called him “the best young player in Europe”. He was dazzling in that spell but that was a spell when Liverpool were only playing one game a week. This is the first time he has played consistently in games of this pressure and will take some adjustment. But If he is become one of the top players in the world, he will ultimately have to get used to an unrelenting schedule.\nWrite highlights for this article.\n\nAssistant: 1. Raheem Sterling faced controversy over expressing fatigue and not being in the right condition to play for England.\n2. Sterling's game requires high intensity sprinting and sharpness to go past opponents.\n3. Managing Sterling's schedule and ensuring proper rest can prevent burnout and long-term injuries.\n4. Sterling's importance to both Liverpool and England at an early stage in his career puts more pressure on him.\n5. Communication between England manager Roy Hodgson and Liverpool manager Brendan Rodgers is crucial to handle Sterling's development and progress."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: By. Thomas Burrows for MailOnline. Cameraman Huw Walters allegedly forged his wife's signature to pay off £59,000 in debts. He denies three counts of fraud at Cardiff Crown Court. An award-winning cameraman forged his wife's signature to pay off £59,000 in debts, before divorcing her, a court heard. Huw Talfryn Walters, 50, allegedly spent the money on buying expensive camera and photographic equipment. But the court heard the freelance cameraman kept his spiralling money problems secret from his wife, Helen Bowen. The case was brought before a new jury after the original jury failed to reach a majority verdict in March this year. Cardiff Crown Court heard yesterday how Walters forged Miss Bowen's signature three times to take £59,000 from their joint mortgage on their marital home in Rhiwbina, Cardiff. During one application, he wrote: 'We need these funds urgently, so your speed in these matters would be welcome.' He then ended their 14-year marriage after telling his wife he needed to put their house up for sale. Walters allegedly told Miss Bowers: 'I'm divorcing you because if I don't I'm going to go bankrupt.' Prosecutor John Davies said: 'Miss Bowen made a career change around 2005 and after that left the mortgage and bills to her husband who she trusted. 'They separated in December 2011 and not long after that she became aware the mortgage had increased without her agreement. 'She contacted the bank to obtain copies of the three withdrawals which increased the size of the mortgage by £59,000. 'She realised the signatures on them didn't belong to her - they were written by her husband without her knowledge while they still lived together.' Freelance cameraman Walters has worked on programmes including the BBC Proms, Doctor Who and Channel 4's history show Time Team. He has won two Welsh Baftas for his cinematography and camera work in recent years. The court heard Walters spent thousands of pounds on expensive camera equipment. Mr. Davies told the jury: 'A total of £59,000 was acquired using the fake. signature of his then wife. He entered each signature having not. consulted her or asked to consent.' Scroll down for video. Cardiff Crown Court (pictured) heard how Walters forged Miss Bowen's signature three times to take £59,000 from their joint mortgage on their marital home in Rhiwbina, Cardiff. The court heard from a forensic handwriting expert, who proved Helen Bowen's signature had not been written by her - and was most likely written by her husband. Miss Bowers, a professional reflexologist, said she only discovered her husband had forged her signature a few months after their separation - when she was contacted by her mortgage provider. When asked if she agreed to the mortgage withdrawals, she replied: 'No, not at all. If I had been asked I would not have agreed to them.' Walters, now of Maesteg, in Bridgend, denies three charges of fraud. The trial continues. Sorry we are not currently accepting comments on this article.\nHere is a summary of the highlights for this article:\n\nAssistant: Cameraman Huw Talfryn Walters is accused of forging his wife's signature to pay off £59,000 in debts before divorcing her. Walters allegedly spent the money on expensive camera equipment and kept his financial issues a secret from his wife, Helen Bowen. The case is being heard at Cardiff Crown Court after the original jury failed to reach a majority verdict in March. Walters is accused of forging Bowen's signature three times to take money from their joint mortgage. He denies three charges of fraud, and the trial is ongoing."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Staff is always attentive and friendly and on occasion the owner stops by to chit chat and see how service and food is, which is nice to see! Rose and Brandon are always a pleasure to be around and give great service and conversation! Food is excellent as usual, great menu for a bar! Friday happy hour is outstanding, and the bar is not crazy loud even though it is very busy!\nIs this review positive or negative?\nPick your answer from:\nI. negative.\nII. positive.\nAnswer:\n\nAssistant: To determine if the review is positive or negative, I'll analyze the key points mentioned in the review.\n1. \"Staff is always attentive and friendly\" - Positive\n2. \"the owner stops by to chit chat and see how service and food is\" - Positive\n3. \"Rose and Brandon are always a pleasure to be around and give great service and conversation!\" - Positive\n4. \"Food is excellent as usual, great menu for a bar!\" - Positive\n5. \"Friday happy hour is outstanding\" - Positive\n6. \"the bar is not crazy loud even though it is very busy\" - Positive\nBased on the analysis, the review contains positive feedback and sentiments.\nAnswer: II. positive."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Le message de l'orateur au Gouvernement espagnol, c'est que Gibraltar ne fera jamais partie de l'Espagne, que la décolonisation du territoire aura lieu conformément à la volonté de son peuple et qu'il n'y a rien que l'Espagne puisse faire à cet égard.\nWhich language is this?\n\nAssistant: This language is French. It's a beautiful language spoken mostly in France, but also in many other countries like Canada, Belgium, and some African countries. The message talks about Gibraltar, which is a special place, and how it won't be part of Spain."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Choose between \"sandbox\" and \"patchy grass\". Question: The toddler moved faster in the sandbox than in the patch of grass because it had (A) more resistance (B) less resistance\n\nAssistant: The toddler moved faster in the sandbox than in the patch of grass because it had (B) less resistance."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: How would one describe the sentiment of this tweet?\n@clairemsp tians alot\n\nAssistant: To provide an in-depth description of the sentiment of the tweet \"@clairemsp tians alot\", one needs to first dissect and analyze the various elements present in it. There are a few important assumptions we have to make in order to analyze the tweet, as the text has some errors and is therefore open to interpretation. \n1. Context: It is crucial to consider the context surrounding the tweet, such as the subject matter, any previous conversation, the Twitter user's profile, and recent related events. A sentiment analysis based solely on the text may provide a skewed or misleading interpretation.\n2. Text interpretation: Taking the given tweet at face value, it appears to contain a typing error. We can assume that the correct version could be \"@clairemsp means alot\" or \"@clairemsp thanks a lot\", with the latter being a more plausible interpretation as 'tians' could possibly be a misspelled or typo version of 'thanks'. This interpretation suggests a message of gratitude. Semantically, this phrase signifies a positive sentiment.\n3. User's Tone: Analyzing the specific tweet, we can observe that the user is directly addressing \"clairemsp\" and sharing that something \"means a lot\" or they are thankful. This could imply that clairemsp has contributed to something valuable or helpful; however, the tweet lacks specific details. Despite that, overall, the tone of the tweet appears to be appreciative and possibly positive.\nConsidering these elements, the sentiment of the tweet can be described as predominantly positive. Here, the author is expressing gratitude to the person they've mentioned in the tweet and is likely appreciative of their help, contribution, or impact. It is essential to note that understanding the broader context and tone of the conversation or topic would provide better insight and accuracy in determining the tweet sentiment. However, based on the available information, it can be concluded that the sentiment expressed in this tweet appears to be positive."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Read the following paragraph and determine if the hypothesis is true:\nIn mathematics, the Vitali–Carathéodory theorem is a result in real analysis that shows that, under the conditions stated below, integrable functions can be approximated in L from above and below by lower- and upper-semicontinuous functions, respectively. It is named after Giuseppe Vitali and Constantin Carathéodory.\nHypothesis: Vitali–Carathéodory theorem involves mathematics\n\nAssistant: After reading the paragraph provided, it can be inferred that the hypothesis is true. The Vitali–Carathéodory theorem involves mathematics, particularly in the field of real analysis. Real analysis is a branch of mathematics that deals with the study of real numbers, their properties, and the properties of functions and sequences defined on the set of real numbers.\nThe theorem is named after Giuseppe Vitali and Constantin Carathéodory, two mathematicians who studied and made significant contributions to this area of mathematics. The theorem states that under certain conditions, integrable functions (i.e., functions whose integrals or areas under the curve exist) can be approximated in the L space (a function space in mathematics that deals with the study of certain types of functions; specifically, functions that are measurable and have finite norms) from above and below by lower- and upper-semicontinuous functions, respectively.\nLower-semicontinuous functions are functions that have the property that the limit inferior (or greatest lower bound) of the function at a point is less than or equal to the value of the function at that point. Upper-semicontinuous functions, on the other hand, have the property that the limit superior (or least upper bound) of the function at a point is greater than or equal to the value of the function at that point.\nIn summary, the Vitali–Carathéodory theorem is an important result in mathematics dealing with the properties of integrable functions, their approximations using lower- and upper-semicontinuous functions, and their applications to the study of real numbers and real-valued functions. As such, the hypothesis is true: the theorem is indeed a part of mathematics."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: One of the best places to stay in Vegas!!! Love the location, the decor, the fountain with the synchronicity of the music that you can hear from the room. The shops are great and the shows that they offer and to me the casino is the best compare to the others. The pool area is huge with different areas and cabanas are great and spacious. Great restaurants and among all of the hotels that I have been I think it is family oriented without kids bothering .... strange but true!!! Love the hotel!!!\nIs this review positive or negative?\nSelect from:\n(i) negative\n(ii) positive\nAnswer:\n\nAssistant: (ii) positive"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Continue writing the following text.\nIt took a long time for Mary to heal from her fallout with Tanya as\n\nAssistant: the depth of their friendship had made the betrayal even more painful. Mary could still remember the day when Tanya revealed her true colors, leaving her in utter shock. The first step in moving forward was for Mary to fully grasp the situation and come to terms with the fact that Tanya was not the person she thought she knew.\nStep 1: Acceptance\nIn order to heal, Mary had to accept that Tanya was a completely different person from the loyal friend she had trusted for so long. It was challenging as she went over the countless memories they had shared, wondering if Tanya had been pretending the entire time. The process of acceptance was slow and difficult, but it was essential in order for Mary to let go and find closure.\nStep 2: Emotional Expression\nMary found it crucial to express and release the hurt, anger, and sadness that she felt as a result of Tanya's actions. She chose to write in a journal, where she could vent her feelings honestly and privately, without the fear of judgment from others. This outlet allowed her to process her emotions and gain a clearer understanding of her thoughts and experiences.\nStep 3: Seeking Support\nAs Mary worked through her feelings, she reached out to her friends and family for support. They provided her with the much-needed love, encouragement, and a sense of belonging as she faced her challenging emotional journey. Their presence acted as a reminder that despite the loss of one friendship, she still had others who cared for her deeply.\nStep 4: Forgiveness\nAlthough it took a considerable amount of time, Mary eventually reached a point where she was able to forgive Tanya. Forgiveness didn't mean that she condoned Tanya's actions, but rather, she let go of the resentment and negative emotions she had been holding onto. Forgiving Tanya allowed Mary to free herself from the hold that the betrayal had on her life, and in turn, regain her emotional well-being.\nStep 5: Moving Forward\nWith forgiveness achieved, Mary began to rebuild her life without Tanya. She focused on strengthening her relationships with her existing support network, and even forged new connections along the way. While the betrayal was painful, it ultimately pushed Mary towards personal growth and lessons that she would carry with her throughout her life.\nIn the end, the long and intertwining road towards healing taught Mary several valuable lessons about trust, friendship, and resilience. As she closed the chapter on her fallout with Tanya, Mary emerged stronger and more self-aware, ready to embrace new experiences and authentic connections with those who truly deserved her trust."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: Definition: In this task, you are given a set of context paragraph and some supporting facts to answer a question. Your task is to generate answer for given question based on set of context paragraphs and supporting facts.\nInput: Context_1 : Billy Galvin is a 1986 film directed by John Gray. It stars Karl Malden and Lenny von Dohlen. Context_2 : Honey, I Shrunk the Kids: Movie Set Adventure was a playground at Disney's Hollywood Studios in Orlando, Florida. It was located in the Streets of America area and was based on Disney's 1989 film, \"Honey, I Shrunk the Kids\". The playground, which opened almost a year after the film, featured 30 ft blades of grass, and was themed as a movie set for the giant backyard scenes from the film. It was closed in April 2016 to make room for the construction of on the same site. Context_3 : Movement and Location (also called Movement + Location) is an American science fiction movie set in modern-day Brooklyn, directed by Alexis Boling. It stars Bodine Boling, Catherine Missal, Brendan Griffin, Anna Margaret Hollyman, David Andrew Macdonald and John Dapolito. \"Movement and Location\" tells the story of Kim Getty, an immigrant from 400 years in the future who is sent back in time to live an easier life. It premiered at the 2014 Brooklyn Film Festival where it won the Audience Award, Best Screenplay and Best Original Score. Context_4 : Dracula’s Widow is a 1988 vampire thriller film directed by Christopher Coppola , written by Tom Blomquist as Kathryn Ann Thomas, and starring Sylvia Kristel, Josef Sommer and Lenny von Dohlen. Kristel, in the titular role, goes on a killing spree in a seedy 1980s Hollywood, while in search of her husband. Context_5 : Entertaining Angels: the Dorothy Day Story is a 1996 independent film about the life of Dorothy Day, the journalist turned social activist and founder of the \"Catholic Worker\" newspaper. The film stars Moira Kelly as Day, Heather Graham, Lenny Von Dohlen and Martin Sheen. Context_6 : Lenny Von Dohlen (born December 22, 1958) is an American film, television and stage actor, known for his performance as architect Miles Harding in the film \"Electric Dreams\" and as the orchid-growing agoraphobic Harold Smith in the television series \"Twin Peaks\". Von Dohlen currently resides in New York and Los Angeles. Context_7 : Scott Fivelson is the writer/director of \"Near Myth: The Oskar Knight Story\", a Hollywood biopic starring Lenny Von Dohlen as legendary director, \"Oskar Knight\". The film won Scott Fivelson the \"Breakthrough Director Spotlight\" award at the 2016 Hollywood Reel Independent Film Festival. Context_8 : Electric Dreams is a 1984 American-British science fiction romantic comedy-drama film set in San Francisco, California, that depicts a love triangle between a man, a woman and a personal computer. It stars Lenny Von Dohlen, Virginia Madsen, Maxwell Caulfield, and the voice of Bud Cort and was directed by Steve Barron. It was the first film released by the Virgin Films production company. Context_9 : The film concerns a group of high school buddies before moving the action forward to reveal the characters two decades later. Reggie (Adam Lamberg) is the wild card of this group of friends. Perpetually single and lacking ambition, he is prone to making inappropriate comments to girls. When the object of his affection breaks up with her bad-boyfriend, Reggie clumsily asks for a date. Things almost happen, but not quite, and the promise of a kiss still haunts Reggie (Lenny Von Dohlen) 20 years later. On the other hand, Diego (Giancarlo Carmona) and Bonnie (Caker Folley), dated all through high school. But when Diego went to Mexico to help his family, Bonnie (Cynthia Gibb) lost whatever faith she had in men. Now she's a tough lawyer with a young son and a husband she can't trust (John Schneider). When Diego (Eddie Velez) returns and unexpected emotions surface. Finally, Morgan and Dennice were high school sweethearts who married as teens and are still happy together, but secrets suddenly overwhelm them. Context_10 : Alamo Village is a movie set and tourist attraction north of Brackettville, Texas, United States. It was the first movie location built in Texas, originally constructed for and best known as the setting for \"The Alamo\" (1960), directed by John Wayne and starring Wayne, Richard Widmark, Laurence Harvey and Frankie Avalon. fact_1 : Lenny Von Dohlen (born December 22, 1958) is an American film, television and stage actor, known for his performance as architect Miles Harding in the film \"Electric Dreams\" and as the orchid-growing agoraphobic Harold Smith in the television series \"Twin Peaks\". fact_2 : Electric Dreams is a 1984 American-British science fiction romantic comedy-drama film set in San Francisco, California, that depicts a love triangle between a man, a woman and a personal computer. Question: Lenny Von Dohlen was an architect in the 1984 movie set in what location?\nOutput:\n\nAssistant: San Francisco, California\nI used the definition to understand that I need to find the movie in which Lenny Von Dohlen played an architect role in a 1984 movie. I found this information in Context_8, where \"Electric Dreams\" is mentioned as a 1984 movie set in San Francisco, California, and Lenny Von Dohlen is mentioned as playing an architect."} +{"text": "System: User will you give you a task with some instruction. Your job is follow the instructions as faithfully as you can. While answering think step-by-step and justify your answer.\n\nUser: You will be given a definition of a task first, then some input of the task.\nGenerate an appropriate title for the given text. The generated title must be short and include the main topic of the text. The preferred titles are under fifteen words.\nDave LeeNorth America technology reporter Reuters reported that the companies were in negotiations but no deal had been agreed. The move would enable Google to produce its own smartwatches. It already develops a version of Android for wearables. Fitbit's share price increased by 19% before being suspended. Google has not commented. A spokeswoman for Fitbit said it \"does not comment on rumours or speculation\". Reuters did not say what sum had been offered. It emerged in September that Fitbit, currently valued by the markets at around $1.4bn (£1.1bn) after Monday's surge, was touting itself for a sale. Fitbit, which in its latest quarterly earnings posted a loss of $68.5m, is looking precarious as a standalone entity since bigger firms have gained a strong footing in wearable technology. \"A key tipping point is likely to have been Apple's decision to price its Series 3 Apple Watch at $199/£199,\" said Leo Gebbie, from the consultancy CCS Insight. \"That will have put immense pressure on Fitbit's own products, which are already feeling the challenge from low-cost rivals such as [China's] Xiaomi.\" In an attempt to diversify its income, Fitbit has expanded into other areas - including monitoring users' breathing during sleep, and diabetes management. But the firm trails Huawei, Apple and Xiaomi in fitness tech sales. For Google the deal, if confirmed, would represent a renewed effort to develop its Wear OS platform. \"Although Google has been lukewarm in its commitment to wearables recently, this indicates it is serious about the segment and could provide a strong boost to its ambitions,\" Mr Gebbie added. CCS Insight predicted 142 million wearable devices would be sold worldwide this year, at a value of $17.1bn. Meanwhile, Google's parent company Alphabet released its third quarter results which showed a drop in net profit to $7bn in the three months to 30 September. Analysts had been expecting profit of $8.8bn. It fell from $9bn in same period last year when changes to US tax law had boosted Alphabet's bottom line. Total third quarter revenue rose by 20% to $40.5bn.\nOutput:\n\nAssistant: \"Google in Talks to Acquire Fitbit for Wearable Tech Expansion\""} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Translate to Finnish:\nTrafficking in human beings and prostitution continue to reach alarming levels in the European Union, and further socio-economic policies, high-quality public services and effective measures are required to solve the problems of poverty and social exclusion.\n\nAssistant: Step 1: Identify key phrases and terms to be translated.\n- Trafficking in human beings\n- Prostitution\n- Alarming levels\n- European Union\n- Socio-economic policies\n- High-quality public services\n- Effective measures\n- Poverty\n- Social exclusion\nStep 2: Translate each key phrase and term to Finnish\n- Trafficking in human beings: ihmiskauppa\n- Prostitution: prostituutio\n- Alarming levels: hälyttävät tasot\n- European Union: Euroopan unioni\n- Socio-economic policies: sosioekonomiset politiikat\n- High-quality public services: korkealaatuiset julkiset palvelut\n- Effective measures: tehokkaat toimenpiteet\n- Poverty: köyhyys\n- Social exclusion: sosiaalinen syrjäytyminen\nStep 3: Build a Finnish sentence using the translated phrases and terms, maintaining a coherent structure and accurate grammar.\nIhmiskauppa ja prostituutio jatkavat hälyttävien tasojen saavuttamista Euroopan unionissa, ja ongelman ratkaisemiseksi köyhyyden ja sosiaalisen syrjäytymisen osalta tarvitaan lisää sosioekonomisia politiikkoja, korkealaatuisia julkisia palveluja ja tehokkaita toimenpiteitä.\nTranslation: Ihmiskauppa ja prostituutio jatkavat hälyttävien tasojen saavuttamista Euroopan unionissa, ja ongelman ratkaisemiseksi köyhyyden ja sosiaalisen syrjäytymisen osalta tarvitaan lisää sosioekonomisia politiikkoja, korkealaatuisia julkisia palveluja ja tehokkaita toimenpiteitä."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Cavemen was by far the biggest load of crap I have ever wasted my time watching. This show based on the Geico commercials is less entertaining then an actual 30 sec ad for Geico. The makeup was half ass-ed to say the least, hard to imagine a caveman with prefect white teeth even after going to the dentist. This show could of had potential for a funny series if they could of gotten the cast from the commercials, that in it self makes for a lousy show. Perhaps if the writers were the same from the Geico ads this may of had a chance, instead the pilot lacked a good story line. I give this show a 1 out of 10, I would of liked to put a zero out of 10 but that was not an option. I pray for a quick death to this show, I'd give it less then 5 episodes before it dies a deserving death.\nIs the sentiment of this review positive or negative?\n\nAssistant: The sentiment of this review is clearly negative. The reviewer expresses strong dissatisfaction with the show, stating it was a waste of time, criticizing the makeup and casting choices, and wishing for it to quickly end. They also give it a low rating, indicating their overall disapproval of the show."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: În acest scop - după cum aţi propus - trebuie să începem negocierile cu preşedinţia Consiliului cât mai repede posibil, pentru a hotărî o serie de modificări.\nCould you please translate this to English?\n\nAssistant: For this purpose - as you have suggested - we must begin negotiations with the presidency of the Council as soon as possible, in order to decide on a series of changes."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Answer the following question: Question: \"Sassou and the Soviet union signed a pact based on what?\" Context: \"Sassou Nguesso aligned the country with the Eastern Bloc and signed a twenty-year friendship pact with the Soviet Union. Over the years, Sassou had to rely more on political repression and less on patronage to maintain his dictatorship.\" Answer:\nAnswer:\n\nAssistant: Sassou and the Soviet Union signed a pact based on a twenty-year friendship and alignment with the Eastern Bloc.\nStep 1: Read and understand the given context.\nStep 2: Identify the key points - Sassou Nguesso, Eastern Bloc alignment, twenty-year friendship, Soviet Union.\nStep 3: Determine the main focus of the question - wanting to know what the pact was based on.\nStep 4: Match the relevant key points from the context with the question's focus.\nStep 5: Formulate a concise answer using the matched key points from the context."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Options are:\n- No;\n- Yes;\nI am verifying the answers generated by an automatic system to the following question: where are statutes formed Suggested answer: This word is used in contradistinction to the common law . Should I validate this answer?\nAnswer:\n\nAssistant: Step 1: Analyzing the user's query\nThe user has provided two options - 'No' and 'Yes.' They are asking if they should validate the given answer to a question. The question is about where statutes are formed, and the suggested answer is related to the common law.\nStep 2: Evaluating the suggested answer\nThe suggested answer does not specifically address where statutes are formed; it only provides a comparative aspect with the common law. Thus, it is insufficient to answer the question correctly.\nStep 3: Conclusion\nSince the suggested answer does not adequately address the question, the appropriate response should be 'No.' It is recommended not to validate this answer because it does not fully answer the question about where statutes are formed."} +{"text": "System: You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. Think like you are answering to a five year old.\n\nUser: Context: He doubled up in her arms , then went limp . His body rolled over , his face wedging in the roots of a drift stump . Rin struggled to turn him over again ; small pebbles stuck to the bared front of his body . One hand still clamped over the broad smear of blood .\nQuestion Why was the hand still clamped over the broad smear of blood ?\nPossible answers:\nSelect from the following.\na). The water was cold on the wound and it hurt ..\nb). he was embarrased because he was bleeding ..\nc). It was painful ..\nd). None of the above choices ..\nThe answer:\n\nAssistant: c). It was painful.\nExplanation: The hand was still clamped over the broad smear of blood because it was painful. When someone is injured and bleeding, they often hold onto the wound to try to alleviate the pain or to stop the bleeding. \na) is wrong because the passage does not mention anything about the water temperature affecting the wound.\nb) is also wrong because there is no evidence in the passage that the person would be embarrassed about bleeding.\nd) is incorrect because one of the choices, in this case, choice c, is the correct answer."} +{"text": "System: You are an AI assistant. You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. You might need to use additional knowledge to answer the question.\n\nUser: Had dinner last night at Talavera. Very disappointed. Restaurant is beautiful with high end touches. Poor marks for food quality and service. My medium rare buffalo tenderloin was significantly overcooked. No pink was visible. The veggies, on the contrary were undercooked. The carrots were warm on the outside and cold in the middle. I never send food back but was ready to make an exception in this case. However, the waitstaff was nowhere to be found. After receiving my entree 40 min elapsed before anyone followed up to see how we were doing. I don't necessarily blame the servers as they seemed to be significantly understaffed. I expect much, much more from an expensive restaurant at a Four Seasons resort. Extremely disappointing.\nChoose the sentiment of this review?\nPick your answer from: A. negative. B. positive.\n\nAssistant: A. negative\nThe sentiment of this review is negative because the reviewer expresses disappointment in multiple aspects of their experience at Talavera, such as poor food quality, overcooked and undercooked dishes, and slow service. They also mention that they expect more from an expensive restaurant at a Four Seasons resort, further emphasizing their dissatisfaction."} +{"text": "System: You are an AI assistant, who knows every language and how to translate one language to another. Given a task, you explain in simple steps what the task is asking, any guidelines that it provides. You solve the task and show how you used the guidelines to solve the task.\n\nUser: Detailed Instructions: In this task, we can are given an english sentence and the goal is to generate a text in 'turkish' language that is a faithful translation of the input sentence\nQ: CHP\nA:\n\nAssistant: Task Explanation: The task is to translate an English sentence into Turkish, ensuring that the translation is accurate and stays true to the original sentence's meaning.\nGiven English Sentence: \"CHP\"\nGuidelines:\n1. Understand the context of the given English sentence or abbreviation.\n2. Translate the English sentence into Turkish.\nSolution:\nIn this case, \"CHP\" is an abbreviation for the \"Cumhuriyet Halk Partisi\" which is a political party in Turkey. It is not an English sentence, but we can still provide the translation in Turkish since it is already related to the Turkish context.\nUsing Guidelines:\n1. I identified that \"CHP\" is an abbreviation for the Turkish political party \"Cumhuriyet Halk Partisi\".\n2. As it is already in the Turkish context, there is no need to translate. So, the answer is:\nA: Cumhuriyet Halk Partisi"} +{"text": "User: Given the question: Given this review: \"Latest update is making this a battery hog. Used to be out of my Top 10.\" Would you recommend this app to a friend? Not at all, No, Maybe, Yes, or Definitely?\nThe answer is:\n\nAssistant: No"} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Solve -40*o + 63 = -17*o - 26*o for o.\n\nAssistant: To solve the equation -40*o + 63 = -17*o - 26*o for o, follow these steps:\nStep 1: Combine like terms on both sides of the equation.\nThe equation currently has two terms with the variable o on the right side, so combine them:\n-40*o + 63 = (-17*o - 26*o)\n-40*o + 63 = -43*o\nStep 2: Isolate the variable term on one side of the equation.\nIn this case, we'll want to get all the terms with the variable o on the left side of the equation and all the constant terms on the right side. To do this, add 43*o to both sides to eliminate the -43*o term from the right side:\n(-40*o + 43*o) + 63 = -43*o + 43*o\n3*o + 63 = 0\nStep 3: Isolate the variable by moving the constant term to the other side.\nSubtract 63 from both sides to isolate the 3*o term:\n3*o + 63 - 63 = 0 - 63\n3*o = -63\nStep 4: Solve for the variable.\nDivide both sides by 3 to find the value of o:\n3*o / 3 = -63 / 3\no = -21\nThe solution is o = -21."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Loud Snoring
Tim started to snore as he got older. It frustrated his wife. They tried different solutions. None seem to work. Eventually Tim had to take medicine to breath better.\nDoes this next sentence follow, given the preceding text?\nTim was a snorer.\n\nAssistant: Yes, the next sentence, \"Tim was a snorer,\" does follow given the preceding text as it summarizes the main issue explained previously.\nThe preceding text provides context and background about Tim's snoring issue. As Tim got older, he began to snore loudly, which caused frustration for his wife. They attempted various solutions to address this issue, but none of them were effective. Ultimately, Tim had to take medicine to improve his breathing.\nThe summary statement, \"Tim was a snorer,\" is an appropriate follow-up to the preceding text as it concisely encapsulates the central problem discussed. It may not encompass all the specific details mentioned earlier (such as trying different solutions or taking medication), but it accurately represents the overarching theme of Tim's snoring and its impact on his life and his wife's well-being."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Lee had their driving test coming up so Lee booked more lessons and practiced harder. Given the question \"What does Lee need to do before this?\", is \"book their driving test\" a valid answer?\nChoose from:\n + Yes;\n + No;\nAnswer:\n\nAssistant: No;"} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Do these sentences have the same meaning?\nThe researchers found only empty cavities and scar tissue where the tumors had been .\nNo tumors were detected ; rather , empty cavities and scar tissue were found in their place .\n\nAssistant: Yes, these two sentences essentially convey the same meaning, although they are structured differently and use somewhat different phrasing. Let's break down the meaning of each sentence to better understand their similarities.\nSentence 1: \"The researchers found only empty cavities and scar tissue where the tumors had been.\"\nIn this sentence, the focus is on what the researchers found in the location where tumors previously existed. The use of \"only\" emphasizes that instead of finding any tumors, the researchers discovered \"empty cavities and scar tissue.\" This implies that the tumors are no longer present, and what remains in their place is the aftermath of their removal or disappearance (i.e., the empty cavities and scar tissue).\nSentence 2: \"No tumors were detected; rather, empty cavities and scar tissue were found in their place.\"\nThis sentence is more explicit about the absence of tumors. The first phrase, \"No tumors were detected,\" clearly states that the researchers did not find any tumors. The second part of the sentence, following the semicolon, elaborates on this finding by using the word \"rather\" to indicate that instead of tumors, they found \"empty cavities and scar tissue\" in the location where tumors would be expected. This also implies that the tumors are no longer there and have been replaced by these empty cavities and scar tissue.\nBoth sentences express the idea that the researchers were looking for tumors but discovered empty cavities and scar tissue instead. Consequently, the two sentences convey the same meaning, even though they use different wording and structure."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Determine the topic of the question-answer pair. Question: \"how many presidents of the us?\"; Answer: \"There have been 43 people sworn into office, and 44 presidencies, as Grover Cleveland served two non-consecutive terms and is counted chronologically as both the 22nd and 24th president.\"? Topic:\nA:\n\nAssistant: The U.S. Presidency or U.S. Presidents"} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Read the following article and answer the question. Article: It is easy to say \"to be Best Friends Forever (BFF)\". However, a true BFF helps out when the other is in trouble, not just once or twice, but for a lifetime. Milo is six years old. He has become Eddie's BFF. They are special because they are not humans, but dogs. Their owner, Angie Baker, brought them up. They spent many years playing together. However, last year, Eddie suddenly started walking into dustbins and walls. Soon he lost his eyesight. When Milo felt something was wrong, he decided to act as Eddie's \"guide\" dog. He leads Eddie through the house. When they go for walks, he holds Eddie's leash . Once, Milo had to spend a few nights at the hospital himself. Poor Eddie often bumped into walls and furniture. Ever worse, he got lost during walks. Angie realized how much Eddie depended on his friend. Luckily, Milo is now fine and back at his BFF's side, guiding him through life again. Question: Eddie's BFF is _ . Answer:\nAnswer:\n\nAssistant: Eddie's BFF is Milo."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: In this week's <I>TravelWatch</I> column, <I>National Geographic Traveler</I> geotourism editor Jonathan B. Tourtellot examines a stewardship program that keeps Maine's ocean islands untrampled.\nWhat best summarizes the content of the above article?\npick from the following.\n (1). World.\n (2). Sports.\n (3). Business.\n (4). Science/Tech.\n\nAssistant: To summarize the content of the article, I will go through it step by step:\n1. The article is from the \"TravelWatch\" column, indicating it is related to travel.\n2. It is part of \"National Geographic Traveler,\" which is a reputable source for travel and exploration content.\n3. The geotourism editor's focus in the article is on a stewardship program in Maine.\n4. The purpose of the program is to protect Maine's ocean islands from being trampled on.\nFrom these points, the most suitable summary topic for the content of the article would be:\n(1). World."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: By. Sam Greenhill. and Daniel Martin. Created:. 19:52 EST, 20 February 2012. Jobless people attending courses run by David Cameron's 'back to work' tsar  Emma Harrison were ordered to sign blank timesheets, it was alleged yesterday. Others complained of being forced to waste hours sitting doing nothing instead  of being allowed to look for gainful employment. A damning dossier of allegations about Mrs Harrison's employment agency A4e has been handed to the Comptroller and Auditor General by former minister Fiona  Mactaggart. She is the Labour MP for Slough, Berkshire, where last week A4e's offices were  visited by police as part of a fraud probe. The Prime Minister is facing demands to suspend lucrative Whitehall contracts  which A4e won to help the long-term unemployed find jobs. 'Posh commune': A4e boss Mrs Harrison's £5million pile has 16 bedrooms, 100 acres of land, a pool, spa, a bar, a nightclub and a long dining table for banquets. Millionairess Mrs Harrison, 48 - made. families tsar by Mr Cameron after she  boasted she could help problem. families find work - triggered outrage when she  paid herself. £8.6million of mainly taxpayers' cash. Ministers. have sought urgent reassurance there is no 'systematic fraud' of the. public purse, after Thames Valley Police were called in by the. Department of  Work and Pensions. Ms. Mactaggart's dossier contains accusations from constituents who were. made to  attend an A4e 'back to work' course as a condition of receiving. state benefits. One claimed: 'Every Friday, the \"clients\" have to sign a blank timesheet and  the times are filled in later. 'Members of my group were told to fill in a blank piece of paper outlining any  \"modules\"...and to sign a blank \"review\". 'I would guess this was to fraudulently claim that all \"criteria\" as laid down  by the government was adhered to.' Another. complainant, a construction worker, said he had spent 13 weeks on an. A4e course supposedly to renew a forklift operator's licence yet had. 'received  no training at all', despite repeatedly asking for some. Attendance. on a back-to-work course is mandatory for the long-term unemployed  if. they want to keep being paid benefits, and companies such as A4e charge. taxpayers millions of pounds to run them. But a third person described a 'chaotic, unorganised' A4e course in which 'we  would mostly be shunted in and out of rooms, staring at walls, long silences  with none of the staff telling the group what was happening'. She told the MP her group was forced to 'learn' menial tasks such as how to  write a letter on a computer. She said: 'I did not learn anything that I did  not know already, and it is ridiculous and worrying that people like myself  should be forced into this position where I have a proven work history.' Last night Ms Mactaggart, a former Home Office minister and a member of the  influential Public Accounts Committee, said: 'I think there is a series of what  appears to me on the face of it to be fraudulent actions - expecting people to  sign blank time sheets and so on. Prime Minister David Cameron is facing calls to sack his 'back to work' tsar after a fraud probe was launched. 'It is in itself wrong to get vulnerable unemployed people to sign blank time  sheets, because this company is in the position where it has the power to put  people in breach of their Jobseekers' Allowance conditions. 'In my view, that makes A4e not fit to receive a government contract. The  government should absolutely consider suspending the contracts. 'The concerns that my constituents have raised are sufficient for me to believe  that they are not a proper company to get taxpayers' money and to be dealing  with vulnerable people.' The MP forwarded the complaints to Sir Amyas Morse, the Comptroller and Auditor  General - the head of the National Audit Office which scrutinises the accounts  of government departments and agencies. She demanded an investigation, telling him the cases seemed to be examples of  'A4e taking money from the Government without actually delivering anything  worthwhile'. Other MPs have attacked Mrs Harrison's vast dividend, which she paid herself  after A4e recorded a £234million turnover. A4e - which used to be called Action  For Employment - makes all its UK income from state contracts. Mrs Harrison, worth an estimated £70million, lives in a £7million country  mansion in Derbyshire with her husband Jim, their four children, and a group of  11 close friends and their six children in a 'posh commune'. Labour's Margaret Hodge, who is chairman of the Public Accounts Committee, has led calls for Mrs Harrison's company to have its state contracts suspended. Margaret Hodge, the chairwoman of the Public Accounts Committee, yesterday  tabled a parliamentary question asking whether the Government would suspend the  company's contracts pending inquiries. A Thames Valley Police spokesman said: 'Officers visited the offices of A4e in  Slough as part of an allegation of fraud, which was referred to the Force by  the Department for Work and Pensions. The investigation is ongoing and at an  early stage, therefore we are currently unable to comment any further.' A4e said the police had only come in 'for a meeting' and had not taken away any  documents. It said the fraud probe concerned an isolated incident two years ago  relating to four employees who had since left the company, and that A4e had discovered the alleged fraud itself and reported it to the Department for Work  and Pensions within 48 hours. Last night an A4e spokesman insisted the concerns raised by Ms Mactaggart had  been resolved, saying: 'Under the new Work Programme, customers are no longer  required to fill out timesheets. 'All our staff have clear guidelines on how to complete claims made under Work  Programme and these claims are independently verified by the employer and by  JobCentre Plus. 'A4e has a zero tolerance policy towards any form of behaviour that might lead  to allegations of fraud. 'The concerns raised by Fiona Mactaggart's constituents are not linked to the  current Thames Valley Police investigation.'\nHere is a summary of the highlights for this article:\n\nAssistant: - David Cameron's 'back to work' tsar Emma Harrison's employment agency A4e has been accused of ordering jobless people to sign blank timesheets.\n- Former minister Fiona Mactaggart has handed a damning dossier of allegations about A4e to the Comptroller and Auditor General.\n- A4e's offices in Slough were visited by police as part of a fraud probe.\n- The Prime Minister is facing demands to suspend lucrative Whitehall contracts awarded to A4e.\n- A4e has denied any systematic fraud, and claims that the police visit was for a meeting, not to seize documents.\n- Emma Harrison has previously faced controversy for paying herself £8.6 million in taxpayer-funded dividends."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Here is some data: Cocum eatType coffee shop; Cocum food Japanese; Cocum priceRange £20-25; Cocum customer rating high; Cocum familyFriendly no.\nWrite a sentence that describes this data:\n\nAssistant: Cocum is a coffee shop that specializes in serving Japanese cuisine, with a price range of £20-25 per person, highly appreciated by customers for their delightful taste, but not particularly considered family-friendly."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: The last remaining personal property belonging to Ponzi schemer Bernie Madoff and his wife, Ruth, netted nearly $500,000 in an auction in Miami Beach, Florida, the U.S. Marshals Service said Tuesday. Proceeds from the June 4 auction go to the U.S. Department of Justice Assets Forfeiture fund and will be used to compensate \"the many victims of this crime of historical proportions,\" said Neil DeSousa, U.S. marshal for the Southern District of Florida. \"Rest assured that the U.S. Marshals Service's top priority in this case was restitution toward the victims.\" To date, about $24 million in property sales and $80 million in cash assets from the Madoffs have been recovered, the Marshals Service said. Art, antiques, furniture and clothing from the Madoff residences in New York and Palm Beach, Florida, were featured in the 275 lots included in the final auction. Madoff, the mastermind of an estimated $50 billion Ponzi scheme, pleaded guilty in March 2009 to 11 felony charges, including money laundering, perjury, false filings with the SEC and fraud. He was sentenced to 150 years in prison.\nWhat are the important parts of this article?\n\nAssistant: 1. The last remaining personal property of Bernie Madoff and his wife, Ruth, was auctioned off in Miami Beach, Florida, and raised nearly $500,000.\n2. The U.S. Marshals Service announced that the proceeds from the June 4 auction will go to the U.S. Department of Justice Assets Forfeiture fund.\n3. The funds will be used to compensate the victims of Madoff's Ponzi scheme, which was a crime of historical proportions.\n4. To date, about $24 million in property sales and $80 million in cash assets from the Madoffs have been recovered by the U.S. Marshals Service.\n5. The final auction featured 275 lots, including art, antiques, furniture, and clothing from the Madoff residences in New York and Palm Beach, Florida.\n6. Bernie Madoff was the mastermind of an estimated $50 billion Ponzi scheme and pleaded guilty in March 2009 to 11 felony charges. He was sentenced to 150 years in prison."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: This is some data: Alan Bean, almaMater, \"UT Austin, B.S. 1955\"; Alan Bean, timeInSpace, \"100305.0\"(minutes); Alan Bean, selectedByNasa, 1963.\nGenerate a detailed description of this data.\nSentence:\n\nAssistant: The given data provides valuable information about Alan Bean, a notable astronaut. Bean attended the University of Texas at Austin, where he graduated with a Bachelor of Science degree in 1955, as indicated by the almaMater attribute. His time in space amounts to an impressive 100,305.0 minutes, which reveals his extensive experience and dedication to space exploration. Lastly, we learn that he was selected by NASA in 1963, marking the beginning of his career as an astronaut and his significant contributions to expanding human knowledge of outer space."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Add spaces between the words in the following text: SolomonruledoverallthekingdomsfromtheRivertothelandofthePhilistines,andtotheborderofEgypt:theybroughttribute,andservedSolomonallthedaysofhislife.\n\nAssistant: 1. Identify the text provided: \"SolomonruledoverallthekingdomsfromtheRivertothelandofthePhilistines,andtotheborderofEgypt:theybroughttribute,andservedSolomonallthedaysofhislife.\"\n2. Locate possible word boundaries by recognizing common words, proper nouns, and prefixes/suffixes.\n3. Begin separating words with spaces:\n - \"Solomon ruled overall the kingdoms from the River to the land of the Philistines, and to the border of Egypt: they brought tribute, and served Solomon all the days of his life.\"\n4. Reassess the text and verify if all word boundaries have been addressed:\n - The word \"overall\" should be separated into two words \"over all.\"\n5. Finalize the text with correct word boundaries:\n - \"Solomon ruled over all the kingdoms from the River to the land of the Philistines, and to the border of Egypt: they brought tribute, and served Solomon all the days of his life.\"\n6. Task complete: The words in the provided text have been separated with spaces to make it more readable."} +{"text": "System: User will you give you a task with some instruction. Your job is follow the instructions as faithfully as you can. While answering think step-by-step and justify your answer.\n\nUser: Teacher:In this task you will be given a list of dictionaries. A dictionary is a set of key-value pairs, where each key is unique and has a value associated with that key. You should sort the list of dictionaries from smallest to largest by their 'first' key. If there is two dictionaries with the same 'first' value then sort them by their 'second' key. Negative numbers should come before positive numbers.\nTeacher: Now, understand the problem? Solve this instance: [{'first': 12, 'second': 17}, {'first': -88, 'second': 99}, {'first': -90, 'second': -3}, {'first': 64, 'second': -24}, {'first': -23, 'second': -4}, {'first': -6, 'second': -62}, {'first': 2, 'second': 0}]\nStudent:\n\nAssistant: First, we need to sort the dictionaries by their 'first' key value. We will use the built-in Python 'sorted' function which provides a way to sort elements by providing a function to compute a sort key from an element. In this case, the key function will be a lambda that extracts the 'first' key value from the dictionary.\nSecond, in case where two dictionaries have the same 'first' value, they should be sorted by 'second' key. The 'sorted' function provides a way to specify multiple sort keys by returning a tuple from the key function. Python will sort the tuples primarily by the first element, then the second element in case of ties, and so on.\nLet's solve:\ninput = [{'first': 12, 'second': 17}, {'first': -88, 'second': 99}, {'first': -90, 'second': -3}, {'first': 64, 'second': -24}, {'first': -23, 'second': -4}, {'first': -6, 'second': -62}, {'first': 2, 'second': 0}]\nWe will use the 'sorted' function with key function that returns tuple with 'first' and 'second' values of dictionary:\nresult = sorted(input, key=lambda d: (d['first'], d['second']))\nThe result should be the list sorted by 'first' key from smallest to largest, and by 'second' key in case of same 'first' values:\n[{'first': -90, 'second': -3}, {'first': -88, 'second': 99}, {'first': -23, 'second': -4}, {'first': -6, 'second': -62}, {'first': 2, 'second': 0}, {'first': 12, 'second': 17}, {'first': 64, 'second': -24}] \nThis result list satisfies all the conditions provided in the task - dictionaries with smaller 'first' key come before ones with larger 'first' key, and in case of a tie the 'second' key is used for ordering. Also, negative values come before positive ones."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Detailed Instructions: You will be given a sentence that describes a restaurant. You will also be given a few categories of information regarding that sentence. Your task is to fill each of the categories with the appropriate information from the sentenece.\nProblem:Sentence: The coffee shop, xname, is family-friendly and has prices less than 20. It is located near xnear in the city centre. ; Categories: price, location, familyFriendly\nSolution:\n\nAssistant: Step 1: Identify the information corresponding to each category in the sentence.\nPrice: The sentence mentions \"prices less than 20.\"\nLocation: The sentence mentions \"located near xnear in the city centre.\"\nFamily-friendly: The sentence mentions the coffee shop is \"family-friendly.\"\nStep 2: Fill in the categories with the appropriate information found in Step 1.\nPrice: Less than 20\nLocation: Near xnear in the city centre\nFamily-friendly: Yes\nStep 3: Present the completed information in the given categories format.\nSolution:\nCategories:\nPrice: Less than 20\nLocation: Near xnear in the city centre\nFamily-friendly: Yes"} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Six Nations rugby in Dublin with our experts\nDepart from your preferred airport – either London Heathrow or Manchester – bound for Dublin. On arrival you will be met by your Gullivers Sports Travel tour manager and transferred into central Dublin. Check in to your four-star hotel – likely to be the Doubletree by Hilton The Burlington – for two nights on a bed and breakfast basis.\nYou have this afternoon and the evening free to soak up the once-in-a-lifetime experience of spending St Patrick’s Day in Dublin. This used to be a relatively low-key occasion in the Irish capital but is now celebrated with aplomb throughout the city. There is a huge amount going on today, including the parade, which threads through Dublin’s streets, and as convivial an atmosphere as you could wish for in the city’s pubs and bars. You may wish to pack something emerald green into your baggage!\nToday you will enjoy an event put on exclusively for Telegraph readers, led by our rugby news correspondent Gavin Mairs. Over a Saturday brunch – to allow for a gentle start for those who have enjoyed their share of local hospitality the evening before – Gavin will give his insights into the way today’s match is shaping up, as well as leading a Q&A with two former England players (provisionally Ugo Monye and Simon Shaw). You will hear from a player’s perspective about building up for a big match, and what the approach of the two teams is likely to be – in short, the perfect way to whet your appetite for the fixture.\nYou will make your own way to Aviva Stadium, where you will take your seats and enjoy the build-up before watching what could well be the deciding match of the Six Nations championship.\nAfter taking breakfast in your hotel, you have a small amount of time to make the most of your final hours in Dublin before being transferred to the airport by luxury coach and your flight home.\nWhat was this article about?\n\nAssistant: This article is about a special trip for enjoying the Six Nations rugby game in Dublin. The trip includes flying from London Heathrow or Manchester to Dublin, staying in a four-star hotel, and celebrating St Patrick's Day in Dublin. You will also attend an exclusive event with rugby experts and former players who will share their insights and answer questions. Finally, you will watch a Six Nations rugby match at the Aviva Stadium before flying back home."} +{"text": "System: You are an AI assistant. You should describe the task and explain your answer. While answering a multiple choice question, first output the correct answer(s). Then explain why other answers are wrong. You might need to use additional knowledge to answer the question.\n\nUser: Options are:\n (a). operates off a reception stage in front a yellow curtain on a college field while spectators standon wople in the stands and applaud..\n (b). approaches a two sided weight as a woman watches from the side near a door..\n (c). performs with upper arm and shoulders strength, front and back with hand and head and legs, about eighty five degree backs and, arms and feet, extend forward, turn backwards, then grab and hold the weight daily..\n (d). , engage in arm wrestling with one another, usually in pairs..Choose from options above and answer: What most naturally follows?\nA man at a professional sports event, lifts a large and heavy weight using knees, back and arm strength while a woman in a suit watches from the sidelines. a man wearing a \" usa \" shirt, or team uniform\nAnswer:\n\nAssistant: (b). approaches a two sided weight as a woman watches from the side near a door..\nOption (b) is most likely the correct follow up to the scenario described. The reasoning behind this is that it mentions a man approaching a weight (which could be the \"large and heavy weight\" from the initial scenario), and a woman watching from the side (which could be the \"woman in a suit\").\nOption (a) is incorrect because the scenario is taking place on a college field, while in the original context, the event was described as a professional sports event. The settings are inconsistent.\nOption (c) is incorrect because it describes the physical actions of the athlete in an excessive and detailed manner that does not align with the succinct and focused narrative style seen in the original scenario.\nOption (d) is incorrect because it describes an entirely different scenario of arm wrestling. Instead, the original scenario denotes weightlifting, making this option unfit to naturally follow given the context."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: I would like to start this review by saying that I have been coming to Ken's Nails and Spa for a few years now for several different services and - in the past - have always had great service and results. I have always been greeted immediately and have never had to wait more than a couple minutes before I am taken back for services. It truly has been my go-to in the past and I am super sad that I am going to stop frequenting this business. I am not exactly sure what has changed but my last few visits have been anything but pleasant. First, a couple weeks ago I came in for my normal eyebrow wax. Usually I am in and out in a couple minutes with perfectly manicured brows. This time I left with my eyebrows on FIRE. I went to the Albertson's following my appointment and bought some cooling aloe gel for relief but nothing helped. The area ended up breaking out in blisters that last at least a week. I had several people ask me if I had suffered a burn. Yikes! it was not good and extremely painful. I let the bad eyebrow wax slide and thought maybe it was a fluke accident and went back last night for a mani-pedi. I feel bad because I am not sure the tech who was working on me really understood some of the things I was trying to communicate to her. At one point I asked her to file my nails slightly shorter and she said \"ok\" but then just moved on. That wasn't really the worst part though. During both the mani and pedi she repeatedly hurt me while removing my cuticles. I can be a wuss so I know sometimes it is just me being a baby but this was ridiculous. At one point I gasped and pulled my foot away and the nail tech next to her actually had to scold her and tell her to take it easy. Needless to say I left with one nail bleeding which didn't stop for over an hour and is still painful to the touch.\nChoose the sentiment of this review?\npick from the following. I. negative. II. positive.\n\nAssistant: The sentiment of this review is negative."} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Detailed Instructions: You are given a sentence in Portuguese. Your job is to translate the Portuguese sentence into Arabic.\nProblem:Nessa época eu não podia trabalhar na bolsa de valores, não podia conduzir um autocarro nem pilotar um avião.\nSolution:\n\nAssistant: في تلك الفترة لم أستطع العمل في سوق الأوراق المالية، ولم أستطع قيادة حافلة أو قيادة طائرة.\nIn this period, I couldn't work in the stock market, I couldn't drive a bus or fly a plane."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: less than a week after the florida supreme court tossed out a massive class-action lawsuit against big tobacco , lawyers across the state are gearing up for battle on behalf of individuals and families hurt by smoking .\nCan you generate a short summary of the above paragraph?\n\nAssistant: Following the Florida Supreme Court's dismissal of a major class-action lawsuit against big tobacco, attorneys throughout the state are preparing to represent individuals and families affected by smoking."} +{"text": "System: You are an AI assistant that follows instruction extremely well. Help as much as you can.\n\nUser: Detailed Instructions: In this task, you are given a context and four options. Each option is a suggested ending for the context. You should read the context and pick the best ending for the context. Please answer with \"A\", \"B\", \"C\", and \"D\". \nQ: A group of men on horses talk and as they ride away, the clown restarts his job in the ring. the clown (A) smolders up at the man in yellow who claps in response. (B) continues dancing on the horse and then flips the horse and wrists in mid air. (C) starts the advertisement and an introduction comes onto the screen for riding a horse. (D) gets injured again as he jumps over a bulls head and the other clowns help him up.\nA:\n\nAssistant: \"D\""} +{"text": "System: You are an AI assistant. Provide a detailed answer so user don’t need to search outside to understand the answer.\n\nUser: Answer the following question: Title: Surprisingly better than expected Amazon book Product review: The first part of the book to me was more juvenile, about the children, didn't have any idea where it was taking me .. but as you read on, there was a definite shift and a message for everyone whose ever been in this medical situation when you see how the advice given to Katie was understood, taken to heart and followed by her, still a very young girl, and how it affected her siblings and her entire family. She was to be admired. Turned out I discovered it became a delightful story that I am glad I took the time to get past the first few chapters. Would you say this review depicts the product in a flattering or unflattering light?\nAnswer:\n\nAssistant: This review depicts the product in a mostly flattering light. The reviewer begins with a slightly negative comment about the book's juvenile first part but goes on to praise the shift in the story, its message, and the impact on the characters as the book progresses. The reviewer ultimately describes the book as a delightful story worth reading and is glad they continued past the first few chapters. While the review starts with a mild criticism, the overall sentiment is positive and flattering towards the book."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Answer the following question: Title: only the beginning Review: Warlock is the first in a trilogy by author Oakley Hall, the second novel in the trilogy being Badlands, followed by Apaches. I was simply awed by the writing of Mr Hall, and the universal human truths he reminds the reader of. I can see that more than a few writers must have read Oakley Hall's novels, most especially Cormac Mccarthy. Warlock was published in 1958, and Badlands was at least 10 yrs later, followed by Apaches, which was at least another decade later. Mr Hall also does the fine Ambrose Bierce series of novels, and with a career spanning 5 decades, he is still underated and underapreciated by the general public. do yourself a favor and discover this most excellent writer. Is this product review negative?\nOptions:\n + Yes\n + No\nAnswer:\n\nAssistant: No"} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: What is a question about this article? If the question is unanswerable, say \"unanswerable\".\nIsrael, and The US Air Force, in conjunction with the members of NATO, has developed significant tactics for air defence suppression. Dedicated weapons such as anti-radiation missiles and advanced electronics intelligence and electronic countermeasures platforms seek to suppress or negate the effectiveness of an opposing air-defence system. It is an arms race; as better jamming, countermeasures and anti-radiation weapons are developed, so are better SAM systems with ECCM capabilities and the ability to shoot down anti-radiation missiles and other munitions aimed at them or the targets they are defending.\nIn agreement with NATO members, the US Air Force and which country has created tactics for air defence suppression?\n\nAssistant: Israel, in conjunction with the US Air Force and members of NATO, has created tactics for air defense suppression."} +{"text": "System: User will you give you a task with some instruction. Your job is follow the instructions as faithfully as you can. While answering think step-by-step and justify your answer.\n\nUser: Q: Given a sentence in the Japanese and Indonesian(Bahasa variant) language. Your task is check if the Bahasa Indonesia sentence is translation of Japanese. if the translation is correct than generate label \"Yes\", otherwise generate label \"No\".\nJapanese: 彼は「ここに集まっている忠実な信者に、行方不明者の無事な帰還、救助に命を懸けた男女や、この悲劇的な事故で失われた魂のために祈ることを求める」と言った。 \n Bahasa Indonesia: Dia banyak menegaskan posisi yang sama dalam percakapan telepon dengan Presiden Amerika Serikat, George W. Bush, hari ini.\nA:\n\nAssistant: To check if the Bahasa Indonesia sentence is a translation of the Japanese sentence, we need to have a general understanding of the meaning of both sentences and compare them.\nJapanese: 彼は「ここに集まっている忠実な信者に、行方不明者の無事な帰還、救助に命を懸けた男女や、この悲劇的な事故で失われた魂のために祈ることを求める」と言った。 \nTranslation (in English): He said, \"I ask the faithful believers gathered here to pray for the safe return of the missing, for the men and women who risked their lives in the rescue, and for the souls lost in this tragic accident.\"\nBahasa Indonesia: Dia banyak menegaskan posisi yang sama dalam percakapan telepon dengan Presiden Amerika Serikat, George W. Bush, hari ini.\nTranslation (in English): He strongly emphasized the same position in a phone conversation with the President of the United States, George W. Bush, today.\nComparing the meaning of both sentences, it is clear that the Bahasa Indonesia sentence is not a translation of the Japanese sentence.\nLabel: No"} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: Q: You are given a sentence in English. Your job is to translate the English sentence into Spanish.\nAnd last year I spoke at his memorial service and wished him well in this life.\nA:\n\nAssistant: Y el año pasado hablé en su servicio conmemorativo y le deseé lo mejor en esta vida.\nIn order to come up with the answer, I used the definition and knowledge of both languages. First, I identified the main parts and ideas of the sentence: \"last year,\" \"I spoke,\" \"his memorial service,\" \"wished him well,\" and \"in this life.\" Then, I translated each part of the sentence into Spanish, while ensuring that the grammar and structure are correct. \n- \"last year\" -> \"el año pasado\"\n- \"I spoke\" -> \"hablé\"\n- \"his memorial service\" -> \"su servicio conmemorativo\"\n- \"wished him well\" -> \"le deseé lo mejor\"\n- \"in this life\" -> \"en esta vida\"\nFinally, I combined these translations into a coherent sentence in Spanish, keeping the original meaning and context of the English sentence. So the final translation is: \"Y el año pasado hablé en su servicio conmemorativo y le deseé lo mejor en esta vida.\""} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Here is a premise:\nThe Symposium has been the forum for debate and discussion of a wide range of equal justice issues, from the mission of the civil legal assistance delivery system to the design and implementation of a number of specific initiatives aimed at unifying, expanding and improving the state's equal justice community.\nHere is a hypothesis:\nA range of equal justice issues have been debated.\nHere are the options: pick from the following. -- yes -- it is not possible to tell -- no\nIs it possible to conclude that if the premise is true, then so is the hypothesis?\n\nAssistant: Yes, it is possible to conclude that if the premise is true, then so is the hypothesis. \nThe premise states that The Symposium has been the forum for debate and discussion of a wide range of equal justice issues. This implies that there is a variety of topics covered under the umbrella of equal justice, and they have been discussed at this particular forum. \nThe hypothesis states that a range of equal justice issues have been debated. The implication here is that various equal justice topics have been discussed and debated, which aligns with the information provided in the premise.\nSince the premise and hypothesis are aligned and the information in the hypothesis is derived from the premise, we can conclude that if the premise is true, then the hypothesis is also true."} +{"text": "System: Explain how you used the definition to come up with the answer.\n\nUser: Given the task definition and input, reply with output. In this task, you are given a sentence in the English and Hindi language. Your task is check if the Hindi sentence is translation of English. if the translation is correct than generate label \"Yes\", otherwise generate label \"No\".\nEnglish: \"The typical incentive programs simply do not resonate in today's economy,\" Jesse Toprak, an analyst for Edmunds. \n Hindi: एडमंड्स के एक विश्लेषक जेसी टोप्राक ने कहा, \"सामान्य प्रोत्साहन कार्यक्रम केवल आज की अर्थव्यवस्था में प्रतिध्वनित नहीं होते हैं।\"\n\nAssistant: To determine if the Hindi sentence is a translation of the English sentence, I would first analyze and compare the context and meaning of both sentences. \nEnglish: \"The typical incentive programs simply do not resonate in today's economy,\" Jesse Toprak, an analyst for Edmunds.\nHindi: एडमंड्स के एक विश्लेषक जेसी टोप्राक ने कहा, \"सामान्य प्रोत्साहन कार्यक्रम केवल आज की अर्थव्यवस्था में प्रतिध्वनित नहीं होते हैं।\"\nBased on my understanding of the sentences, the Hindi sentence is indeed a translation of the English sentence. Both sentences mention that typical incentive programs do not resonate in today's economy and attribute the statement to Jesse Toprak, an analyst for Edmunds. Therefore, I would generate the label \"Yes.\""} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Summarize:\nTwenty-year-olds who started antiretroviral therapy in 2010 are projected to live 10 years longer than those first using it in 1996, it found.\nDoctors say that starting treatment early is crucial to achieve a long and healthy life.\nCharities say there are still too many people unaware they have the virus.\nThis is particularly true in the developing world, where the majority of HIV deaths occur because access to drugs is limited.\nThe study authors, from the University of Bristol, said the extraordinary success of HIV treatments was a result of newer drugs having fewer side effects and being better at preventing the virus from replicating in the body.\nIt is also more difficult for the virus to build up a resistance to the most recent drugs.\nImproved screening and prevention programmes and better treatment of health problems caused by HIV are thought to have helped, too.\nBut many people with HIV still do not live as long as expected, especially those infected through injecting drugs.\nAntiretroviral therapy involves a combination of three or more drugs which block the normal progress of HIV (human immunodeficiency virus).\nThey have been called \"one of the greatest public health success stories of the past 40 years\".\nJimmy Isaacs, 28, discovered he had been infected with HIV by a former partner nearly three years ago.\nHe takes three drugs once a day at 18:00 and will continue to do so for the rest of his life.\n\"My health is absolutely fine. I'm eating healthily and drinking healthily,\" he said.\n\"It doesn't impact on my job and hasn't impacted on my social life either.\"\nAlthough it took two changes of medication to find the right combination for him, he says he now has no side effects at all.\n\"I had heard a lot of bad stories about the drugs back in the '90s - but when I did some research, I realised the drugs had completely changed.\"\nNot all his employers have been supportive since his diagnosis and he says that is down to ignorance.\nHis current employer has given him time off to tour the country and speak to students and school pupils about HIV prevention and treatment.\nThe researchers looked at 88,500 people with HIV from Europe and North America who had been involved in 18 studies.\nThey based their life-expectancy predictions on death rates during the first three years of follow-up after drug treatment was started.\nThey found that fewer people who started treatment between 2008 and 2010 died during this period compared with those who began treatment between 1996 and 2007.\nThe expected age at death of a 20-year-old patient starting antiretroviral therapy (ART) after 2008, with a low viral load and after the first year of treatment, was 78 years - similar to the general population.\nDr Michael Brady, medical director at the Terrence Higgins Trust, said the study showed how much things had changed since the start of the HIV epidemic in the 1980s.\nBut he said it also meant people aged over 50 now represented one in three of all those living with HIV.\n\"As it stands, the healthcare, social care and welfare systems simply aren't ready to support the increasing numbers of people growing older with HIV.\n\"We need a new model of care to better integrate primary care with HIV specialist services, and we need a major shift in awareness and training around HIV and ageing, so that we're ready to help older people live well in later life,\" he said.\nProf Helen Stokes-Lampard, who chairs the Royal College of GPs, said: \"It's a tremendous medical achievement that an infection that once had such a terrible prognosis is now so manageable, and that patients with HIV are living significantly longer.\n\"We hope the results of this study go a long way to finally removing any remaining stigma associated with HIV, and ensuring that patients with HIV can live long and healthy lives without experiencing difficulties in gaining employment and - in countries where it is necessary - obtaining medical insurance.\"\nShe said steps were being taken to increase appropriate HIV testing by GPs.\nThe proportion of people with undiagnosed HIV has fallen steadily over the past 20 years.\nBut one in eight people with HIV is still thought to remain undiagnosed.\nGet news from the BBC in your inbox, each weekday morning\nSummary:\n\nAssistant: A study by the University of Bristol found that 20-year-olds who began antiretroviral therapy for HIV in 2010 are expected to live 10 years longer than those who started treatment in 1996. Newer drugs have fewer side effects and are more capable in preventing the virus from replicating in the body, with the most recent drugs also being more resistant. Improved screening and treatment options have also contributed to the rise in life expectancy. However, many people with HIV still do not live as long as expected, particularly those who contracted the virus through drug injections."} +{"text": "System: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\nUser: Premise: Other hands joined his to boost Anse.\nHypothesis: He let Anse fall to the floor. \nIs the hypothesis entailed by the premise?\nChoose your answer from: * yes; * it is not possible to tell; * no; And the answer is:\n\nAssistant: * no;\nAnd the answer is: no.\nReasoning: The premise states that \"other hands joined his to boost Anse,\" which suggests that they are supporting or helping Anse. The hypothesis, however, states that \"he let Anse fall to the floor,\" which contradicts the act of support or help mentioned in the premise."} +{"text": "User: Title: Flayderman's Guide to Antique American Firearms and Their Values Product review: I ordered this for my husband and it was just what he wanted. Could have used a more up-to-date printbut don't think it's available yet.Thanks. Would you say this review depicts the product in a flattering or unflattering light?\nThe answer to this question is:\n\nAssistant: flattering"} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Riding high from their two-straight road wins, the Saints returned home to the Louisiana Superdome for the first time since December 26, 2004 for a special Monday Night game against their fellow NFC South division rival, the Atlanta Falcons, in front of a jubilant sold-out crowd of 70,003 and ESPN's largest-ever audience, with about 10.8 million households viewing, at the time the second-largest cable audience in history (behind the 1993 North American Free Trade Agreement (NAFTA) debate between Al Gore and Ross Perot on CNN). On just the fourth overall play of the game, Safety Steve Gleason blocked a punt and DB Curtis Deloatch managed to land on the ball in the end zone for the Saints first score of the game. Former Saints kicker Morten Andersen helped Atlanta get its only score of the game, from a 26-yard field goal. From there, the Saints dominated the rest of the game. For the final score of the first quarter, WR Devery Henderson ran 11 yards for a touchdown on a reverse. Kicker John Carney would provide the scoring for the rest of the game, as he kicked two second-quarter field goals (a 37-yarder and a 51-yarder), and one third-quarter field goal (a 20-yarder) to essentially wrap up the game. From there, the Saints defense continued to put pressure on Falcons quarterback Michael Vick, as they sacked him five times, while limiting the running duo of Vick and Dunn to just a combined total of 101 rushing yards. With their dominating performance, the Saints would take the lead in the NFC South with a surprising 3-0 record. Green Day and U2 performed jointly on the field before the game, while the Goo Goo Dolls held a concert outside the Dome earlier in the evening.\nAnswer this question based on the article: How many yards did Morten Anderson kick?\n\nAssistant: Based on the article, former Saints kicker Morten Andersen helped Atlanta get its only score of the game from a 26-yard field goal."} +{"text": "System: You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n\nUser: Here's a complex question that requires someone to reason about the input, can you answer it? What film noir starring Lee Jung-jae shows the conflict between the police and the mob?\nA:\n\nAssistant: The film noir starring Lee Jung-jae that showcases the conflict between the police and the mob is \"New World\" (Korean title: 신세계; RR: Sinsegye), a South Korean crime thriller film released in 2013. Directed and written by Park Hoon-jung, the movie dives into the dark, tense, and twisted world of organized crime, specifically focusing on the complex power dynamics that govern the interactions between the mob and the police.\nThe film's plot revolves around an undercover police officer, Lee Ja-sung, played by Lee Jung-jae, who infiltrates one of the largest and most powerful crime syndicates in Korea, known as Goldmoon. Lee Ja-sung spends years working his way up the ranks and develops close relationships with key members of the organization, including its second-in-command, Jung Chung, played by Hwang Jung-min, and an ambitious, ruthless rising star in the syndicate, Lee Joong-goo, played by Park Sung-woong.\nAs the story unfolds, we see how Lee Ja-sung's unwavering loyalty to his police chief, Kang Hyung-chul, played by Choi Min-sik, complicates his mission and forces him to make choices that impact the lives of those around him. The film explores themes of loyalty, betrayal, power struggles, and the blurred lines between the worlds of law enforcement and organized crime.\nNew World has been praised for its gripping narrative, evocative cinematography, and exceptional performances, especially those of Lee Jung-jae and Hwang Jung-min. The film's portrayal of the intricate and often treacherous relationships between the main characters contributes to its complex exploration of the eternal conflict between the police and mob organizations.\nIn summary, the film noir \"New World\" starring Lee Jung-jae captures the conflict between the police and the mob through its powerful storytelling and engaging characters, offering a poignant and insightful look into the murky world of organized crime in South Korea."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: This transparent dialogue also guarantees the citizens of Bulgaria and Romania that the fundamental rights enshrined in the treaties are a reality.\nTranslate to German\n\nAssistant: Dieser durchsichtige Dialog stellt auch sicher, dass für die Bürgerinnen und Bürger von Bulgarien und Rumänien die grundlegenden Rechte, die in den Verträgen verankert sind, Wirklichkeit werden.\nLass mich das für dich einfach erklären: Das bedeutet, dass ein offenes Gespräch hilft, den Menschen in Bulgarien und Rumänien sicherzustellen, dass sie wirklich die wichtigen Rechte haben, die in den Abkommen geschrieben stehen."} +{"text": "System: You are an AI assistant, who knows every language and how to translate one language to another. Given a task, you explain in simple steps what the task is asking, any guidelines that it provides. You solve the task and show how you used the guidelines to solve the task.\n\nUser: Detailed Instructions: You are given a conversation between two people. 'Person1:' and 'Person2:' are used to separate their respective dialogues. You are required to assign a label 'formal' if there is an absence of emotion and a presence of questions anywhere within the conversation. If such pattern is not found assign the label 'informal'.\nProblem: Person1: Room service . Can I help you ? \n Person2:Yes , tomorrow is my friend's birthday.Could I arrange a birthday party for him through you ? \n Person1: Certainly . \n Person2:I want to order a large birthday cake , several kinds of cold dishes , pastries and fruits . \n Person1: Any drinks ? \n Person2:I've got whiskey and brandy but I need some beer and soda.That ' s all . \nSolution:\n\nAssistant: Step 1: Identify the task\nThe task is to read a conversation between two people and determine whether it is formal or informal by checking for the presence of emotion and questions.\nStep 2: Understand the guidelines\nIf there is an absence of emotion and the presence of questions, the conversation is considered formal. If this pattern is not found, it is considered informal.\nStep 3: Analyze the conversation\nPerson1: Room service. Can I help you?\nPerson2: Yes, tomorrow is my friend's birthday. Could I arrange a birthday party for him through you?\nPerson1: Certainly.\nPerson2: I want to order a large birthday cake, several kinds of cold dishes, pastries, and fruits.\nPerson1: Any drinks?\nPerson2: I've got whiskey and brandy but I need some beer and soda. That's all.\nStep 4: Check for emotions and questions\nIn this conversation, there are no expressions of emotions. We can see the presence of questions from both Person1 and Person2.\nStep 5: Assign the label\nSince there is an absence of emotion and the presence of questions, the conversation can be labeled as 'formal'.\nSolution: Formal"} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Orrin Grant Hatch (born March 22, 1934) is an American attorney and politician serving as the senior United States Senator for Utah who has been the President pro tempore of the United States Senate since 2015. Having been a senator since 1977, Hatch is the longest-serving Republican Senator in U.S. history. Hatch served as either the chair or ranking minority member of the Senate Judiciary Committee from 1993 to 2005. He previously served as chairman of the Senate Committee on Health, Education, Labor, and Pensions from 1981 to 1987 and currently serves as Chairman of the Senate Finance Committee as well as serving on the board of directors for the United States Holocaust Memorial Museum.\nHatch introduced the Ensuring Patient Access and Effective Drug Enforcement Act, narrowing the broad authority of the DEA to suspend drug \"manufacturers, distributors, and dispensers\". Hatch stated the bill was also written to protect patients from disruptions in the production and delivery of their prescription drugs: \"The fact that prescription drugs can be abused should not prevent patients from receiving the medications they need. This bill takes a balanced approach to the problem of prescription drug abuse by clarifying penalties for manufacturing or dispensing outside approved procedures while helping to ensure that supply chains to legitimate users remain intact\". The bill passed the Senate unanimously and Tom Marino passed a version of the bill in the House and was signed by President Barack Obama. Critics of the bill claim the new law fuels the opioid crisis by limiting the DEA's ability to halt production and distribution by predatory drug companies. DEA Chief Administrative Law Judge John J. Mulrooney II wrote in the Marquette Law Review: \"At a time when, by all accounts, opioid abuse, addiction and deaths were increasing markedly, this new law imposed a dramatic diminution of the agency's authority. It is now all but logically impossible for the DEA to suspend a drug company's operations for failing to comply with federal law.\" Donald Trump's Attorney General Jeff Sessions called the law \"dubious\" and joined 44 state attorneys general calling for \"repeal or amendment of the law to restore some of the DEA's authority.\" Jim Geldhof, a former DEA program manager whom spent 43 years with the DEA called the bill \"outrageous. It basically takes any kind of action DEA was going to do with a distributor or manufacturer as far as an immediate suspension off the table. And then the other part of that really infuriates me is that corrective action plan.\" Mulrooney compared the corrective action plan to one that would \"allow bank robbers to round up and return inkstained money and agree not to rob any more banks -- all before any of those wrongdoers actually admit fault and without any consequence that might deter such behavior in the future.\" Hatch responded to a Washington Post and 60 Minutes investigation into the bill by writing a Washington Post opinion article calling the investigation \"misleading\" and asking to \"leave conspiracy theories to Netflix\". Senator Sheldon Whitehouse, a co-sponsor of the senate bill, also defended the bill: \"This bill was drafted in consultation with the DEA to offer better guidance for companies working to safely and responsibly supply prescription drugs to pharmacies, and to promote better communication and certainty between companies and regulators.\" Republican Pat Toomey Pat Toomey expressed doubts that a conspiracy existed, but still suggested amending the bill: \"I'm a little surprised that it passed unanimously in both houses, was signed by President Obama and got no opposition from the DEA at the time. That's not the way controversial legislation usually ends up, but hey, if there's problems, then we ought to revisit them.\" Hatch received $177,000 in donations from the drug industry while pushing the bill through and has received $2,178,863 from the Pharmaceuticals/Health Products industry from 1989-2014 according to required filings by the Federal Election Commission.\nAsk a question about this article.\n\nAssistant: Orrin Hatch is a very important man in the United States government. He has been a senator, which means he helps make laws, for a very long time - since 1977! He is from the state of Utah and is a member of the Republican party. He has worked on many different committees that focus on things like health, education, and laws. He created a bill, which is like an idea for a new rule, to make sure people can get the medicine they need but also make sure those medicines are not abused. This bill became a law, but some people think it may not be the best solution and might make another problem, called the opioid crisis, worse. Orrin Hatch has also received a lot of money from companies that make medicines, which some people think might make him support rules that are good for those companies."} +{"text": "System: You are a helpful assistant, who always provide explanation. Think like you are answering to a five year old.\n\nUser: Write a subject line for this message:\nNick: It looks like Skilling will be able to attend the conference on May 15th. However, he will need to speak in the morning, because his flight leaves on the 15th at 3:50 p.m. Let us know if that will work. Thanks. \nSubject Line:\n\nAssistant: Subject Line: Skilling Can Come to the Conference, but Needs Morning Time to Talk"} diff --git a/finetune/lora/v6/demo/infctx.sh b/finetune/lora/v6/demo/infctx.sh new file mode 100644 index 0000000..127fd5a --- /dev/null +++ b/finetune/lora/v6/demo/infctx.sh @@ -0,0 +1,25 @@ +load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' +proj_dir='/home/rwkv/JL/out_model/infctx' +data_file='/home/rwkv/JL/data/roleplay' + + +n_layer=24 +n_embd=2048 + +micro_bsz=8 +epoch_save=5 +epoch_steps=1000 +ctx_len=16384 +chunk_ctx=2048 + + +python train.py --load_model $load_model \ +--proj_dir $proj_dir --data_file $data_file \ +--data_type binidx --vocab_size 65536 \ +--ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ +--n_layer $n_layer --n_embd $n_embd \ +--pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ +--lora_load rwkv-0 --lora --lora_r 64 --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ +--my_testing "x060" --dataload pad \ +--train_type infctx --chunk_ctx $chunk_ctx --fla --wandb infctx \ No newline at end of file diff --git a/finetune/lora/v6/fla/__init__.py b/finetune/lora/v6/fla/__init__.py new file mode 100644 index 0000000..b500e55 --- /dev/null +++ b/finetune/lora/v6/fla/__init__.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- + +from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet, + GatedLinearAttention, HGRN2Attention, LinearAttention, + MultiScaleRetention, ReBasedLinearAttention) +from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM, + DeltaNetModel, GLAForCausalLM, GLAModel, + HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM, + HGRNModel, LinearAttentionForCausalLM, + LinearAttentionModel, RetNetForCausalLM, RetNetModel, + RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM, + TransformerModel) +from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based, + fused_chunk_gla, fused_chunk_retention) + +__all__ = [ + 'ABCAttention', + 'BasedLinearAttention', + 'DeltaNet', + 'HGRN2Attention', + 'GatedLinearAttention', + 'LinearAttention', + 'MultiScaleRetention', + 'ReBasedLinearAttention', + 'ABCForCausalLM', + 'ABCModel', + 'DeltaNetForCausalLM', + 'DeltaNetModel', + 'HGRNForCausalLM', + 'HGRNModel', + 'HGRN2ForCausalLM', + 'HGRN2Model', + 'GLAForCausalLM', + 'GLAModel', + 'LinearAttentionForCausalLM', + 'LinearAttentionModel', + 'RetNetForCausalLM', + 'RetNetModel', + 'RWKV6ForCausalLM', + 'RWKV6Model', + 'TransformerForCausalLM', + 'TransformerModel', + 'chunk_gla', + 'chunk_retention', + 'fused_chunk_based', + 'fused_chunk_gla', + 'fused_chunk_retention' +] + +__version__ = '0.1' diff --git a/finetune/lora/v6/fla/layers/__init__.py b/finetune/lora/v6/fla/layers/__init__.py new file mode 100644 index 0000000..cb8e442 --- /dev/null +++ b/finetune/lora/v6/fla/layers/__init__.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +from .abc import ABCAttention +from .based import BasedLinearAttention +from .delta_net import DeltaNet +from .gla import GatedLinearAttention +from .hgrn import HGRNAttention +from .hgrn2 import HGRN2Attention +from .linear_attn import LinearAttention +from .multiscale_retention import MultiScaleRetention +from .rebased import ReBasedLinearAttention +from .rwkv6 import RWKV6Attention + +__all__ = [ + 'ABCAttention', + 'BasedLinearAttention', + 'DeltaNet', + 'GatedLinearAttention', + 'HGRNAttention', + 'HGRN2Attention', + 'LinearAttention', + 'MultiScaleRetention', + 'ReBasedLinearAttention', + 'RWKV6Attention' +] diff --git a/finetune/lora/v6/fla/layers/abc.py b/finetune/lora/v6/fla/layers/abc.py new file mode 100644 index 0000000..4f4a9cc --- /dev/null +++ b/finetune/lora/v6/fla/layers/abc.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.cache_utils import Cache + +from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding, + ShortConvolution) +from fla.modules.activations import swiglu, swish +from fla.modules.convolution import proj_then_conv1d +from fla.ops.abc.chunk import chunk_abc + + +class ABCAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_low_rank_dim: int = 16, + gate_logit_normalizer: int = 16, + use_input_gate: bool = False, + use_output_gate: bool = True, + use_norm: bool = True, + clamp_min: Optional[float] = -32, + clamp_max: Optional[float] = 32, + layer_idx: Optional[int] = None, + **kwargs + ) -> ABCAttention: + super().__init__() + + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.key_dim = int(self.hidden_size * self.expand_k) + self.value_dim = int(self.hidden_size * self.expand_v) + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + + self.gate_low_rank_dim = gate_low_rank_dim + self.gate_logit_normalizer = gate_logit_normalizer + + self.use_input_gate = use_input_gate + self.use_output_gate = use_output_gate + self.use_norm = use_norm + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.norm_eps = norm_eps + + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + + if use_output_gate: + self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') + + if self.use_norm: + if self.use_output_gate: + self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) + else: + self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) + + if self.use_rope: + self.rotary = RotaryEmbedding(self.head_k_dim) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + + if self.use_short_conv: + if self.share_conv_kernel: + hidden_states = self.h_conv1d(hidden_states) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias) + k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias) + v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_input_gate: + q, k, v = map(lambda x: swish(x), (q, k, v)) + + if self.use_rope: + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + k = rearrange(k, '... (h d) -> ... h d', h=self.num_heads) + seqlen_offset = 0 + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + q, k = self.rotary(q, k, seqlen_offset) + q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads) + k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads) + else: + q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) + k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads) + v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads) + + # [batch_size, n_heads, seq_len, num_slots] + s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads) + s = s.clamp_(self.clamp_min, self.clamp_max) + + last_state = past_key_values[self.layer_idx] if use_cache else None + o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache) + if past_key_values is not None and last_state is not None: + past_key_values.update(last_state, self.layer_idx, q.shape[2]) + + o = rearrange(o, 'b h t d -> b t h d') + if self.use_norm and not self.use_output_gate: + o = self.g_norm(o) + elif self.use_output_gate: + g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads) + o = self.g_norm(o, g) if self.use_norm else swiglu(g, o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots), + param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim)) + return state + + def state_size(self, sequence_length: int = 2048): + return self.num_heads * self.key_dim * self.head_v_dim diff --git a/finetune/lora/v6/fla/layers/based.py b/finetune/lora/v6/fla/layers/based.py new file mode 100644 index 0000000..bed0c16 --- /dev/null +++ b/finetune/lora/v6/fla/layers/based.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- + +""" +Linear attention in Based. +https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import TaylorFeatureMap +from fla.ops.based import parallel_based +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn + + +class BasedLinearAttention(nn.Module): + def __init__( + self, + hidden_size: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + self.hidden_size + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel", 'chunk'] + + # linear attention + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_key_value_heads + self.causal = causal + + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + self.feature_map = TaylorFeatureMap(feature_dim) + self.eps = eps + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) + if mode == "fused_chunk": + q, k = self.feature_map(q), self.feature_map(k) + o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'chunk': + q, k = self.feature_map(q), self.feature_map(k) + o = chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_based(q, k, v, True, True) + o = rearrange(o, "b h l d -> b l (h d)") + o = self.o_proj(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, l) + y (torch.Tensor): tensor of shape (b, d, l) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, l, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) + v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h l d -> b l (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) + + +if __name__ == '__main__': + batch = 4 + seq_len = 1024 + hidden_size = 1024 + dtype = torch.float32 + x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda() + model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda() + y = model(x) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + y2 = model.forward_reference(x) + y2.backward(dy) + assert y.allclose(y2, 0, 1e-4), breakpoint() + assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() + print("Pass") diff --git a/finetune/lora/v6/fla/layers/delta_net.py b/finetune/lora/v6/fla/layers/delta_net.py new file mode 100644 index 0000000..194b18a --- /dev/null +++ b/finetune/lora/v6/fla/layers/delta_net.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- + +# Sect4.2 of Linear Transformers Are Secretly Fast Weight Programmers https://arxiv.org/abs/2102.11174 + + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.cache_utils import Cache + + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution, LayerNorm +from fla.modules.rotary import RotaryEmbedding +from fla.ops.delta_rule import (fused_chunk_delta_rule, + fused_recurrent_linear_attn_delta_rule, + chunk_delta_rule) +from torch.nn import functional as F + + +def simple_norm(x): + return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x) + + +# @torch.jit.script +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +# @torch.jit.script +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +# @torch.jit.script +def elu_norm(x): + dtype = x.dtype + x = F.elu(x, 1., False) + 1. + return (x / x.sum(-1, keepdim=True)).to(dtype) + + + + +# https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 +class DeltaNet(nn.Module): + def __init__( + self, + d_model: int = None, + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 4, + mode: str = 'fused_chunk', + chunk_size: int = 16, + use_beta: bool = True, + use_gate: bool = True, + use_rope: bool = False, + use_output_norm: bool = True, + use_elu: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = False, + layer_idx: int = None, + qk_activation: str = 'silu', + qk_norm: str = None, + save_memory: str = False, + **kwargs + ) -> DeltaNet: + super().__init__() + self.mode = mode + self.qk_activation = qk_activation + self.qk_norm = qk_norm + assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] + assert self.qk_norm in ['l2', 'sum'] + if d_model is not None: + hidden_size = d_model + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.chunk_size = chunk_size + self.use_gate = use_gate + self.use_output_norm = use_output_norm + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.layer_idx = layer_idx + + self.silu = torch.nn.SiLU() + + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.use_beta = use_beta + self.use_elu = use_elu + if self.use_beta: + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation=None) + else: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None) + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + if self.use_gate: + self.norm = FusedRMSNormSwishGate(self.head_v_dim) + else: + self.norm = RMSNorm(self.head_v_dim) + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + + # change to inference mode. + mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode + last_state = past_key_values[self.layer_idx] if use_cache else None + + if attention_mask is not None: + if attention_mask.shape[-1] != hidden_states.shape[-2]: + attention_mask = attention_mask[:, -1:] + + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + conv_state_q = last_state[0] if use_cache else None + conv_state_k = last_state[1] if use_cache else None + conv_state_v = last_state[2] if use_cache else None + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q = self.q_proj(hidden_states) + q = self.q_conv1d(q, attention_mask, conv_state_q) + k = self.k_conv1d(k, attention_mask, conv_state_k) + v = self.v_conv1d(v, attention_mask, conv_state_v) + else: + q = (self.q_proj(hidden_states)) + k = (self.k_proj(hidden_states)) + v = self.silu(self.v_proj(hidden_states)) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask.unsqueeze(-1)) + + q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) + + if self.qk_activation != 'silu': + if self.qk_activation == 'relu': + q, k = q.relu(), k.relu() + elif self.qk_activation == 'elu': + q, k = elu_p1(q), elu_p1(k) + elif self.qk_activation == 'identity': + pass + else: + raise NotImplementedError + + if self.qk_norm is not None: + if self.qk_norm == 'l2': + k = torch.nn.functional.normalize(k, dim=-1, p=2).to(v) #auto mixed precision type transfer is annoying. + q = torch.nn.functional.normalize(q, dim=-1, p=2).to(v) + elif self.qk_norm == 'sum': + q = sum_norm(q).to(v) + k = sum_norm(k).to(v) + + if self.use_beta: + beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() + else: + beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) + state = past_key_values[self.layer_idx][-1] if use_cache else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_linear_attn_delta_rule(q, k, v, beta, state, output_final_state=use_cache) + elif mode == 'fused_chunk': + assert self.chunk_size in [16, 32, 64] + o, recurrent_state = fused_chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache) + elif mode == 'chunk': + assert self.chunk_size in [16, 32, 64] + o, recurrent_state = chunk_delta_rule(q, k, v, beta, self.chunk_size, state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + state = (conv_state, recurrent_state) + else: + state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state) + else: + state = (recurrent_state,) + past_key_values.update(state, self.layer_idx) + + o = rearrange(o, 'b h l d -> b l h d') + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) + o = self.norm(o, g) + else: + o = self.norm(o) + o = rearrange(o, 'b l h d -> b l (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + # for q/k/v each + state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.value_dim, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),) + return state \ No newline at end of file diff --git a/finetune/lora/v6/fla/layers/gated_abc.py b/finetune/lora/v6/fla/layers/gated_abc.py new file mode 100644 index 0000000..e1bf4fe --- /dev/null +++ b/finetune/lora/v6/fla/layers/gated_abc.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers.cache_utils import Cache + +from fla.modules import (FusedRMSNormSwishGateLinear, RMSNormLinear, + RotaryEmbedding, ShortConvolution) +from fla.modules.activations import ACT2FN, swiglu_linear, swish +from fla.ops.abc.chunk_gate import chunk_gated_abc + + +class GatedABCAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 1., + expand_v: float = 1., + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_low_rank_dim: Optional[int] = None, + gate_logit_normalizer: int = 16, + feature_map: str = 'swish', + use_rope: bool = False, + use_output_gate: bool = False, + use_norm: bool = True, + layer_idx: Optional[int] = None, + **kwargs + ) -> GatedABCAttention: + super().__init__() + + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + + if gate_low_rank_dim is None: + gate_low_rank_dim = self.hidden_size // 16 + self.gate_low_rank_dim = gate_low_rank_dim + self.gate_logit_normalizer = gate_logit_normalizer + + self.feature_map = feature_map + self.use_rope = use_rope + self.use_output_gate = use_output_gate + self.use_norm = use_norm + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False) + self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False) + + if use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + if self.use_norm: + if self.use_output_gate: + self.g_norm = FusedRMSNormSwishGateLinear(self.hidden_size, elementwise_affine, norm_eps) + else: + self.g_norm = RMSNormLinear(self.hidden_size, elementwise_affine, norm_eps) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + if self.use_rope: + self.rotary = RotaryEmbedding(self.head_k_dim) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + + last_state = past_key_values[self.layer_idx] if use_cache else None + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + conv_state_q = last_state[0] if use_cache else None + conv_state_k = last_state[1] if use_cache else None + conv_state_v = last_state[2] if use_cache else None + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q = self.q_conv1d(q, attention_mask, conv_state_q) + k = self.k_conv1d(k, attention_mask, conv_state_k) + v = self.v_conv1d(v, attention_mask, conv_state_v) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + f = self.f_proj(hidden_states) + + if self.use_rope: + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads) + seqlen_offset = 0 + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + q, k = self.rotary(q, k, seqlen_offset) + q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads) + k = rearrange(k, 'b n h d -> b h n d', h=self.num_kv_heads) + else: + q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) + if self.num_kv_groups > 1: + k = repeat(k, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups) + else: + k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_kv_heads) + if self.num_kv_groups > 1: + v = repeat(v, 'b n (h d) -> b (h g) n d', h=self.num_kv_heads, g=self.num_kv_groups) + f = repeat(f, 'b n (h m) -> b (h g) n m', h=self.num_kv_heads, g=self.num_kv_groups) + else: + v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_kv_heads) + f = rearrange(f, 'b n (h m) -> b h n m', h=self.num_kv_heads) + + if self.feature_map is not None: + q, k, v = map(lambda x: ACT2FN[self.feature_map](x), (q, k, v)) + f = F.logsigmoid(f) / self.gate_logit_normalizer + s = (1 - f.exp()).to(f.dtype) + # dealing with left-padding + if attention_mask is not None: + s = s.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1)) + v = v.mul_(attention_mask.view(attention_mask.shape[0], 1, -1, 1)) + + recurrent_state = last_state[-2:] if use_cache else None + o, recurrent_state = chunk_gated_abc(q, k, v, s, f, + initial_state=recurrent_state, + output_final_state=use_cache) + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + last_state = (conv_state,) + recurrent_state + else: + last_state = (conv_state_q, conv_state_k, conv_state_v) + recurrent_state + else: + last_state = recurrent_state + past_key_values.update(last_state, self.layer_idx, q.shape[2]) + + o = rearrange(o, 'b h t d -> b t (h d)') + if self.use_norm and not self.use_output_gate: + o = swish(o) + o = self.g_norm(o, self.o_proj.weight, self.o_proj.bias) + elif self.use_output_gate and not self.use_norm: + o = swiglu_linear(self.g_proj(hidden_states), o, self.o_proj.weight, self.o_proj.bias) + elif self.use_output_gate and self.use_norm: + o = self.g_norm(o, self.g_proj(hidden_states), self.o_proj.weight, self.o_proj.bias) + else: + o = self.o_proj(o) + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.value_dim, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots), + param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim)) + return state + + def state_size(self, sequence_length: int = 2048): + return self.num_heads * self.key_dim * self.head_v_dim diff --git a/finetune/lora/v6/fla/layers/gla.py b/finetune/lora/v6/fla/layers/gla.py new file mode 100644 index 0000000..8257196 --- /dev/null +++ b/finetune/lora/v6/fla/layers/gla.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- + + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers.cache_utils import Cache + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules.activations import ACT2FN +from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + + +class GatedLinearAttention(nn.Module): + r""" + The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa + + Args: + mode (str, Optional): + Which GLA kernel to use. + Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 0.5. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `False`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + share_conv_kernel (bool, Optional): + Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`. + use_output_gate (bool, Optional): + Whether to use output gate. Default: `True`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + gate_logit_normalizer (int, Optional): + The normalizer for the gate logits, appied after `logsigmoid`. Default: 16. + gate_low_rank_dim (int, Optional): + The low rank dim for the gate projection. Default: 16. + clamp_min (float, Optional): + The minimum value for the gate logits. Default: None. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. + """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + use_output_gate: bool = True, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 16, + gate_low_rank_dim: int = 16, + clamp_min: Optional[float] = None, + fuse_norm: bool = True, + layer_idx: int = None, + ) -> GatedLinearAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + self.use_output_gate = use_output_gate + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.clamp_min = clamp_min + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True)) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm and use_output_gate: + self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) + self.gate_fn = ACT2FN[gate_fn] + + self.gate_logit_normalizer = gate_logit_normalizer + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode + + last_state = past_key_values[self.layer_idx] if use_cache else None + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + conv_state_q = last_state[0] if use_cache else None + conv_state_k = last_state[1] if use_cache else None + conv_state_v = last_state[2] if use_cache else None + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q = self.q_conv1d(q, attention_mask, conv_state_q) + k = self.k_conv1d(k, attention_mask, conv_state_k) + v = self.v_conv1d(v, attention_mask, conv_state_v) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + gk = self.gk_proj(hidden_states) + + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask.unsqueeze(-1)) + q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads) + if self.num_kv_groups > 1: + k, v, gk = (repeat(x, 'b l (h d) -> b (h g) l d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v, gk)) + else: + k, v, gk = (rearrange(x, 'b l (h d) -> b h l d', h=self.num_kv_heads) for x in (k, v, gk)) + gk = F.logsigmoid(gk) / self.gate_logit_normalizer + + if self.clamp_min is not None: + gk = torch.clamp_min(gk, self.clamp_min) + + recurrent_state = last_state[-1] if use_cache else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache) + elif mode == 'chunk': + o, recurrent_state = chunk_gla(q, k, v, gk, initial_state=recurrent_state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + last_state = (conv_state, recurrent_state) + else: + last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state) + else: + last_state = (recurrent_state,) + past_key_values.update(last_state, self.layer_idx, q.shape[2]) + + o = rearrange(o, 'b h l d -> b l h d') + if self.use_output_gate: + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b l h d -> b l (h d)') + else: + o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)') + o = o * self.gate_fn(g) + else: + o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.value_dim, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),) + return state + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/finetune/lora/v6/fla/layers/hgrn.py b/finetune/lora/v6/fla/layers/hgrn.py new file mode 100644 index 0000000..b852d29 --- /dev/null +++ b/finetune/lora/v6/fla/layers/hgrn.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- + +# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823] + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.cache_utils import Cache + +from fla.modules import FusedRMSNormSwishGate, ShortConvolution +from fla.modules.activations import swiglu +from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn + + +class HGRNAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRNAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.input_dim = int(hidden_size * expand_ratio) + self.head_dim = self.input_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}" + + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') + self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') + + self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode + + last_state = past_key_values[self.layer_idx] if use_cache else None + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + i = self.i_proj(hidden_states) + f = self.f_proj(hidden_states) + else: + conv_state_i = last_state[2] if use_cache else None + conv_state_f = last_state[1] if use_cache else None + i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i) + f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f) + else: + i = self.i_proj(hidden_states) + f = self.f_proj(hidden_states) + + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + i, f = swiglu(i, 1 - g), g.log() + + # dealing with left-padding + if attention_mask is not None: + i = i.mul_(attention_mask.unsqueeze(-1)) + i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f)) + + recurrent_state = last_state[-1] if use_cache else None + if mode == 'chunk': + o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + last_state = (conv_state, recurrent_state) + else: + last_state = (conv_state_i, conv_state_f, recurrent_state) + else: + last_state = (recurrent_state,) + past_key_values.update(last_state, self.layer_idx, i.shape[2]) + + o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)')) + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size), + param.new_zeros(batch_size, self.hidden_size, self.conv_size), + param.new_zeros(batch_size, self.hidden_size, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),) + return state + + def state_size(self, **kwargs) -> int: + state_size = self.hidden_size + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/finetune/lora/v6/fla/layers/hgrn2.py b/finetune/lora/v6/fla/layers/hgrn2.py new file mode 100644 index 0000000..19a3da6 --- /dev/null +++ b/finetune/lora/v6/fla/layers/hgrn2.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- + +# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904] + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.cache_utils import Cache + +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.activations import swish +from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + + +class HGRN2Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRN2Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + else: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + + self.forget_dim = int(self.num_heads * self.expand_ratio) + self.input_dim = hidden_size + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." + assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}" + assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}" + + self.head_f_dim = self.expand_ratio + self.head_i_dim = self.hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') + self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') + + self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode + + last_state = past_key_values[self.layer_idx] if use_cache else None + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + q = self.q_proj(hidden_states) + f = self.f_proj(hidden_states) + i = self.i_proj(hidden_states) + else: + conv_state_q = last_state[0] if use_cache else None + conv_state_f = last_state[1] if use_cache else None + conv_state_i = last_state[2] if use_cache else None + q = self.q_proj(hidden_states) + f = self.f_proj(hidden_states) + i = self.i_proj(hidden_states) + q = self.q_conv1d(q, attention_mask, conv_state_q) + f = self.f_conv1d(f, attention_mask, conv_state_f) + i = self.i_conv1d(i, attention_mask, conv_state_i) + else: + q = self.q_proj(hidden_states) + f = self.f_proj(hidden_states) + i = self.i_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + i = i.mul_(attention_mask.unsqueeze(-1)) + + q = swish(q) + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + k, g = 1 - f.sigmoid(), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + k, g = 1 - g, g.log() + q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g)) + + recurrent_state = last_state[-1] if use_cache else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) + elif mode == 'chunk': + o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + last_state = (conv_state, recurrent_state) + else: + last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state) + else: + last_state = (recurrent_state,) + past_key_values.update(last_state, self.layer_idx, q.shape[2]) + + o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)')) + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size), + param.new_zeros(batch_size, self.forget_dim, self.conv_size), + param.new_zeros(batch_size, self.input_dim, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),) + return state + + def state_size(self, **kwargs) -> int: + state_size = self.forget_dim * self.head_i_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/finetune/lora/v6/fla/layers/linear_attn.py b/finetune/lora/v6/fla/layers/linear_attn.py new file mode 100644 index 0000000..73b3270 --- /dev/null +++ b/finetune/lora/v6/fla/layers/linear_attn.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import RMSNorm +from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap, + HedgehogFeatureMap, T2RFeatureMap) +from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn, + fused_recurrent_linear_attn) + + +class LinearAttention(nn.Module): + def __init__( + self, + hidden_size: str = 1024, + expand_k: int = 1.0, + expand_v: int = 1.0, + num_heads: int = 8, + mode: str = 'chunk', + feature_map: str = 'elementwise_product', + tie_feature_map_qk: bool = False, + output_norm: str = 'rmsnorm', + norm_q: bool = False, + norm_k: bool = False, + # standard linear attention normalization + do_feature_map_norm: bool = False, + elementwise_affine: bool = True, + norm_eps: float = 1e-5, + **kwargs, + ): + super().__init__() + assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp', + 'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`." + + assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`." + + self.hidden_size + self.mode = mode + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.num_heads = num_heads + + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + if feature_map == 'hedgehog': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim) + else: + self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim) + self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim) + + elif feature_map == 't2r': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim) + else: + self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim) + self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim) + + elif feature_map == 'elementwise_product': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim) + else: + self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim) + self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim) + + elif feature_map == 'dpfp': + self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim) + self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim) + + elif feature_map == 'elu': + def elu(x): + return F.elu(x) + 1 + self.feature_map_q = elu + self.feature_map_k = elu + + elif feature_map == 'relu': + self.feature_map_q = nn.ReLU() + self.feature_map_k = nn.ReLU() + + elif feature_map == 'identity': + self.feature_map_q = nn.Identity() + self.feature_map_k = nn.Identity() + else: + raise NotImplementedError + + self.do_feature_map_norm = do_feature_map_norm + if output_norm == 'rmsnorm': + self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) + elif output_norm == 'identity': + self.norm = nn.Identity() + else: + raise NotImplementedError + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.norm_q = norm_q + self.norm_k = norm_k + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward(self, x): + mode = self.mode + q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + q = self.feature_map_q(q) + k = self.feature_map_k(k) + if self.norm_q: + q = q / (q.sum(-1, keepdim=True) + 1e-4) + if self.norm_k: + k = k / (k.sum(-1, keepdim=True) + 1e-4) + + if mode == 'chunk': + o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) + elif mode == 'fused_chunk': + o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) + elif mode == 'fused_recurrent': + o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm) + else: + raise NotImplementedError + o = self.norm(o) + o = rearrange(o, 'b h n d -> b n (h d)') + o = self.o_proj(o) + return o + + +if __name__ == '__main__': + import torch + batch = 4 + seq_len = 1024 + hidden_size = 1024 + x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True) + model = LinearAttention(hidden_size, feature_map='dplp').to(torch.bfloat16).cuda() + y = model(x) + print(y.shape) + y.sum().backward() + print(x.grad.shape) diff --git a/finetune/lora/v6/fla/layers/multiscale_retention.py b/finetune/lora/v6/fla/layers/multiscale_retention.py new file mode 100644 index 0000000..4d143df --- /dev/null +++ b/finetune/lora/v6/fla/layers/multiscale_retention.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.modules.rotary import RotaryEmbedding +from fla.ops.retention import (chunk_retention, fused_chunk_retention, + fused_recurrent_retention, parallel_retention) + + +class MultiScaleRetention(nn.Module): + r""" + The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa + + Args: + mode (str, Optional): + Which Retention kernel to use. + Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`. + Default: `fused_chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + num_heads (int, Optional): + The number of heads. Default: 8. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `False`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + share_conv_kernel (bool, Optional): + Whether to apply convolutions berfore q/k/v mapping, only taking effects when `use_short_conv`. Default: `True`. + use_output_gate (bool, Optional): + Whether to use output gate. Default: `True`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. + """ + + def __init__( + self, + mode: str = 'fused_chunk', + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 2.0, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + share_conv_kernel: bool = True, + use_output_gate: bool = True, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + fuse_norm: bool = True, + layer_idx: int = None, + **kwargs + ) -> MultiScaleRetention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.share_conv_kernel = share_conv_kernel + self.use_output_gate = use_output_gate + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + if share_conv_kernel: + self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') + else: + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm and use_output_gate: + self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) + self.gate_fn = ACT2FN[gate_fn] + + # TODO: fix this issue + # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180 + # Ideally, we would want to support arbitrary d_head_qk + assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256" + self.rotary = RotaryEmbedding(dim=self.head_qk_dim) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode + + last_state = past_key_values[self.layer_idx] if use_cache else None + if self.use_short_conv: + conv_state = last_state[0] if use_cache else None + if self.share_conv_kernel: + # conv state is updated inplace + hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + else: + conv_state_q = last_state[0] if use_cache else None + conv_state_k = last_state[1] if use_cache else None + conv_state_v = last_state[2] if use_cache else None + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q = self.q_conv1d(q, attention_mask, conv_state_q) + k = self.k_conv1d(k, attention_mask, conv_state_k) + v = self.v_conv1d(v, attention_mask, conv_state_v) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask.unsqueeze(-1)) + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads) + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + + seqlen_offset, max_seqlen = 0, None + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + q, k = self.rotary(q, k, seqlen_offset, max_seqlen) + q = q.transpose(1, 2) + if self.num_kv_groups > 1: + k = repeat(k, 'b t h d -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups) + v = repeat(v, 'b t (h d) -> b (h g) t d', h=self.num_kv_heads, g=self.num_kv_groups) + else: + k, v = rearrange(k, 'b t h d -> b h t d'), rearrange(v, 'b t (h d) -> b h t d', h=self.num_kv_heads) + + state = last_state[-1] if use_cache else None + if mode == 'chunk': + o, recurrent_state = chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_retention(q, k, v, initial_state=state, output_final_state=use_cache) + elif mode == 'parallel': + o, recurrent_state = parallel_retention(q, k, v, initial_state=state, output_final_state=use_cache) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + if self.use_short_conv: + if self.share_conv_kernel: + last_state = (conv_state, recurrent_state) + else: + last_state = (conv_state_q, conv_state_k, conv_state_v, recurrent_state) + else: + last_state = (recurrent_state,) + past_key_values.update(last_state, self.layer_idx, q.shape[2]) + + o = rearrange(o, 'b h l d -> b l h d') + if self.use_output_gate: + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b l h d -> b l (h d)') + else: + o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)') + o = o * self.gate_fn(g) + else: + o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = tuple() + if self.use_short_conv: + if self.share_conv_kernel: + state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) + else: + state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.key_dim, self.conv_size), + param.new_zeros(batch_size, self.value_dim, self.conv_size)) + state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),) + return state + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/finetune/lora/v6/fla/layers/rebased.py b/finetune/lora/v6/fla/layers/rebased.py new file mode 100644 index 0000000..3dad7b3 --- /dev/null +++ b/finetune/lora/v6/fla/layers/rebased.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- + +""" +https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import RebasedFeatureMap +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn +from fla.ops.rebased import parallel_rebased + + +class ReBasedLinearAttention(nn.Module): + def __init__( + self, + hidden_size: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 16, + num_heads: int = 16, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True, + causal: bool = True, + eps: float = 1e-5, + mode: str = "parallel", + layer_idx: Optional[int] = None, + **kwargs + ) -> ReBasedLinearAttention: + super().__init__() + self.hidden_size = hidden_size + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel", 'chunk'] + + # linear attention + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_key_value_heads + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + self.causal = causal + + self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize) + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + self.eps = eps + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) + q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel')) + if mode == "fused_chunk": + o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'chunk': + o = chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_rebased(q, k, v, self.eps, True, True) + o = rearrange(o, "b h l d -> b l (h d)") + o = self.o_proj(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, l) + y (torch.Tensor): tensor of shape (b, d, l) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, l, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) + v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h l d -> b l (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) + + +if __name__ == '__main__': + batch = 4 + seq_len = 1024 + hidden_size = 1024 + dtype = torch.float32 + x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True) + dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda() + model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda() + + y = model(x) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + print(model.mode) + model.mode = 'fused_chunk' + y2 = model(x) + print(model.mode) + y2.backward(dy) + # assert y.allclose(y2, 0, 1e-4), breakpoint() + # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() + print("Pass") diff --git a/finetune/lora/v6/fla/layers/rwkv6.py b/finetune/lora/v6/fla/layers/rwkv6.py new file mode 100644 index 0000000..ec44974 --- /dev/null +++ b/finetune/lora/v6/fla/layers/rwkv6.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- + +# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892] + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache + +from fla.modules import FusedLayerNormSwishGate, LayerNorm +from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 + + +class RWKV6Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + gate_fn: str = 'swish', + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + fuse_norm: bool = True, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None, + **kwargs + ) -> RWKV6Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.x_proj = nn.Sequential( + LerpLinear(hidden_size, proj_low_rank_dim * 5), + nn.Tanh(), + nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=True) + ) + self.r_proj = DDLerpLinear(hidden_size, self.key_dim) + self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim) + self.k_proj = DDLerpLinear(hidden_size, self.key_dim) + self.v_proj = DDLerpLinear(hidden_size, self.value_dim) + self.g_proj = DDLerpLinear(hidden_size, self.value_dim) + self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_qk_dim)) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm: + self.g_norm_swish_gate = FusedLayerNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = LayerNorm(self.head_v_dim, elementwise_affine, norm_eps) + self.gate_fn = ACT2FN[gate_fn] + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, nn.Parameter): + nn.init.xavier_uniform_(module, gain=2 ** -2.5) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + batch_size, seq_len, hidden_size = hidden_states.size() + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode + + delta = self.time_shift(hidden_states) - hidden_states + x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim) + r, w, k, v, g = torch.einsum('b l n r, n r d-> b l n d', + self.x_proj[1](x), + self.x_proj[2].weight.view(5, -1, hidden_size)).unbind(-2) + r = self.r_proj(hidden_states, r, delta) + w = self.w_proj(hidden_states, w, delta) + k = self.k_proj(hidden_states, k, delta) + v = self.v_proj(hidden_states, v, delta) + g = self.g_proj(hidden_states, g, delta) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask.unsqueeze(-1)) + r, w, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (r, w, k, v)) + w = -torch.exp(w) + u = self.bonus + + last_state = past_key_values[self.layer_idx] if use_cache else None + state = last_state[-1] if use_cache else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache) + elif mode == 'chunk': + o, recurrent_state = chunk_rwkv6(r, k, v, w, u, initial_state=state, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update((recurrent_state,), self.layer_idx, r.shape[2]) + + o = rearrange(o, 'b h l d -> b l h d') + if self.fuse_norm_and_gate: + g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b l h d -> b l (h d)') + else: + o = self.g_norm(o) + o = rearrange(o, 'b l h d -> b l (h d)') + o = o * self.gate_fn(g) + o = self.o_proj(o) + + return o, None, past_key_values + + def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + param = next(self.parameters()) + state = (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim),) + return state + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + return state_size + + +class LoRA(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: int, + bias: Optional[bool] = True + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + self.bias = bias + + self.lora = nn.Sequential( + nn.Linear(input_dim, low_rank_dim, bias=False), + nn.Tanh(), + nn.Linear(low_rank_dim, output_dim, bias=bias) + ) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(" + s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}" + if not self.bias: + s += f", bias={self.bias}" + s += ")" + return s + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lora(x) + + +class LerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + self.mu = nn.Parameter(torch.zeros(input_dim)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + shifted = self.time_shift(x) + if len(shifted.shape) == 2: + shifted = shifted.unsqueeze(1) + delta = shifted - x + return self.linear(x + delta * self.mu) + + +class DDLerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + shifted = self.time_shift(x) + if len(shifted.shape) == 2: + shifted = shifted.unsqueeze(1) + delta = shifted - x + return self.linear(x + delta * mu) diff --git a/finetune/lora/v6/fla/layers/simple_gla.py b/finetune/lora/v6/fla/layers/simple_gla.py new file mode 100644 index 0000000..43a637a --- /dev/null +++ b/finetune/lora/v6/fla/layers/simple_gla.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.activations import ACT2FN + +from fla.modules import FusedRMSNormSwishGate, RMSNorm +from fla.ops.simple_gla import chunk_simple_gla + + +class SimpleGatedLinearAttention(nn.Module): + r""" + The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa + This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise. + + Args: + mode (str, Optional): + Which GLA kernel to use. + Currently available: `chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 0.5. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + gate_logit_normalizer (int, Optional): + The normalizer for the gate logits, appied after `logsigmoid`. Default: 16. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. + """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 2.0, + num_heads: int = 4, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 16, + fuse_norm: bool = True, + **kwargs + ) -> SimpleGatedLinearAttention: + super().__init__() + self.hidden_size = hidden_size + + self.mode = mode + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + assert mode in ['chunk'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + self.num_heads = num_heads + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.gate_fn = ACT2FN[gate_fn] + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.gk_proj = nn.Linear(hidden_size, self.num_heads) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm: + self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) + + self.gate_logit_normalizer = gate_logit_normalizer + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward(self, x): + mode = self.mode + q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) + gk = rearrange(self.gk_proj(x), 'b n h -> b h n') + gk = (F.logsigmoid(gk) / self.gate_logit_normalizer) + + if mode == 'chunk': + o = chunk_simple_gla(q, k, v, gk) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + o = rearrange(o, 'b h l d -> b l h d') + g = self.g_proj(x) + + if self.fuse_norm_and_gate: + g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b l h d -> b l (h d)') + else: + o = self.g_norm(o) + o = rearrange(o, 'b l h d -> b l (h d)') + o = o * self.gate_fn(g) + o = self.o_proj(o) + return o + + +if __name__ == '__main__': + batch = 4 + seq_len = 1024 + + hidden_size = 2048 + x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True) + model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda() + y = model(x) + print(y.shape) + y.sum().backward() + print(x.grad.shape) diff --git a/finetune/lora/v6/fla/models/__init__.py b/finetune/lora/v6/fla/models/__init__.py new file mode 100644 index 0000000..ffc48e1 --- /dev/null +++ b/finetune/lora/v6/fla/models/__init__.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- + +from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel +from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM, + DeltaNetModel) +from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel +from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel +from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model +from fla.models.linear_attn import (LinearAttentionConfig, + LinearAttentionForCausalLM, + LinearAttentionModel) +from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel +from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel +from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model +from fla.models.transformer import (TransformerConfig, TransformerForCausalLM, + TransformerModel) + +__all__ = [ + 'ABCConfig', 'ABCForCausalLM', 'ABCModel', + 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', + 'GLAConfig', 'GLAForCausalLM', 'GLAModel', + 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', + 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', + 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', + 'MambaConfig', 'MambaForCausalLM', 'MambaModel', + 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', + 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', + 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel' +] diff --git a/finetune/lora/v6/fla/models/abc/__init__.py b/finetune/lora/v6/fla/models/abc/__init__.py new file mode 100644 index 0000000..f7021f2 --- /dev/null +++ b/finetune/lora/v6/fla/models/abc/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.abc.configuration_abc import ABCConfig +from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel + +AutoConfig.register(ABCConfig.model_type, ABCConfig) +AutoModel.register(ABCConfig, ABCModel) +AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) + + +__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] diff --git a/finetune/lora/v6/fla/models/abc/configuration_abc.py b/finetune/lora/v6/fla/models/abc/configuration_abc.py new file mode 100644 index 0000000..3c18557 --- /dev/null +++ b/finetune/lora/v6/fla/models/abc/configuration_abc.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class ABCConfig(PretrainedConfig): + + model_type = 'abc' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + gate_low_rank_dim: int = 16, + clamp_min: float = -32, + clamp_max: float = 32, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_slots: Optional[int] = 64, + use_short_conv: bool = True, + conv_size: int = 4, + share_conv_kernel: bool = True, + exapnd_k: float = 0.5, + exapnd_v: float = 1, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + initializer_range: float = 0.02, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.expand_k = exapnd_k + self.expand_v = exapnd_v + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/abc/modeling_abc.py b/finetune/lora/v6/fla/models/abc/modeling_abc.py new file mode 100644 index 0000000..431df20 --- /dev/null +++ b/finetune/lora/v6/fla/models/abc/modeling_abc.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.abc import ABCAttention +from fla.models.abc.configuration_abc import ABCConfig +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class ABCMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> ABCMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class ABCBlock(nn.Module): + def __init__(self, config: ABCConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = ABCAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + share_conv_kernel=config.share_conv_kernel, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + clamp_max=config.clamp_max, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = ABCMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class ABCPreTrainedModel(PreTrainedModel): + + config_class = ABCConfig + supports_gradient_checkpointing = True + _no_split_modules = ['ABCBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class ABCModel(ABCPreTrainedModel): + + def __init__(self, config: ABCConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ABCForCausalLM(ABCPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ABCModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + model_inputs['past_key_values'] = past_key_values + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/delta_net/__init__.py b/finetune/lora/v6/fla/models/delta_net/__init__.py new file mode 100644 index 0000000..6df3841 --- /dev/null +++ b/finetune/lora/v6/fla/models/delta_net/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.delta_net.configuration_delta_net import \ + DeltaNetConfig +from fla.models.delta_net.modeling_delta_net import ( + DeltaNetForCausalLM, DeltaNetModel) + +AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) +AutoModel.register(DeltaNetConfig, DeltaNetModel) +AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) + +__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] diff --git a/finetune/lora/v6/fla/models/delta_net/configuration_delta_net.py b/finetune/lora/v6/fla/models/delta_net/configuration_delta_net.py new file mode 100644 index 0000000..c8eaaae --- /dev/null +++ b/finetune/lora/v6/fla/models/delta_net/configuration_delta_net.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class DeltaNetConfig(PretrainedConfig): + + model_type = 'delta_net' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + share_conv_kernel: bool = False, + use_rope: bool = False, + use_beta: bool = True, + use_output_norm: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + attn_mode: str = "chunk", + qk_norm: str = 'l2', + qk_activation: str = 'silu', + chunk_size: int = 64, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + rms_norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.attn_mode = attn_mode + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.use_rope = use_rope + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.qk_norm = qk_norm + self.qk_activation = qk_activation + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/delta_net/modeling_delta_net.py b/finetune/lora/v6/fla/models/delta_net/modeling_delta_net.py new file mode 100644 index 0000000..cec4bdc --- /dev/null +++ b/finetune/lora/v6/fla/models/delta_net/modeling_delta_net.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.delta_net import DeltaNet +from fla.models.delta_net.configuration_delta_net import DeltaNetConfig +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class DeltaNetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> DeltaNetMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class DeltaNetBlock(nn.Module): + def __init__(self, config: DeltaNetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps) + self.attn = DeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_rope=config.use_rope, + use_beta=config.use_beta, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + share_conv_kernel=config.share_conv_kernel, + layer_idx=layer_idx, + qk_norm=config.qk_norm, + qk_activation=config.qk_activation + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps) + self.mlp = DeltaNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class DeltaNetPreTrainedModel(PreTrainedModel): + + config_class = DeltaNetConfig + supports_gradient_checkpointing = True + _no_split_modules = ['DeltaNetBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class DeltaNetModel(DeltaNetPreTrainedModel): + + def __init__(self, config: DeltaNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`DeltaNetModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = past_key_values + # if use_cache: + # next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class DeltaNetForCausalLM(DeltaNetPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeltaNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + # breakpoint() + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/gla/__init__.py b/finetune/lora/v6/fla/models/gla/__init__.py new file mode 100644 index 0000000..edccb51 --- /dev/null +++ b/finetune/lora/v6/fla/models/gla/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel + +AutoConfig.register(GLAConfig.model_type, GLAConfig) +AutoModel.register(GLAConfig, GLAModel) +AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) + + +__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] diff --git a/finetune/lora/v6/fla/models/gla/configuration_gla.py b/finetune/lora/v6/fla/models/gla/configuration_gla.py new file mode 100644 index 0000000..f8bf56a --- /dev/null +++ b/finetune/lora/v6/fla/models/gla/configuration_gla.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GLAConfig(PretrainedConfig): + + model_type = 'gla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + use_short_conv: bool = False, + conv_size: int = 4, + share_conv_kernel: bool = True, + use_output_gate: bool = True, + clamp_min: Optional[float] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_gk: bool = True, + use_gv: bool = False, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.clamp_min = clamp_min + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_gk = use_gk + self.use_gv = use_gv + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.use_output_gate = use_output_gate + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/gla/modeling_gla.py b/finetune/lora/v6/fla/models/gla/modeling_gla.py new file mode 100644 index 0000000..9ad4352 --- /dev/null +++ b/finetune/lora/v6/fla/models/gla/modeling_gla.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.gla import GatedLinearAttention +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class GLAMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> GLAMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class GLABlock(nn.Module): + def __init__(self, config: GLAConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = GatedLinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + share_conv_kernel=config.share_conv_kernel, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = GLAMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GLAPreTrainedModel(PreTrainedModel): + + config_class = GLAConfig + supports_gradient_checkpointing = True + _no_split_modules = ['GLABlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class GLAModel(GLAPreTrainedModel): + + def __init__(self, config: GLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GLAForCausalLM(GLAPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GLAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/hgrn/__init__.py b/finetune/lora/v6/fla/models/hgrn/__init__.py new file mode 100644 index 0000000..3b29a3d --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel + +AutoConfig.register(HGRNConfig.model_type, HGRNConfig) +AutoModel.register(HGRNConfig, HGRNModel) +AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) + + +__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] diff --git a/finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py b/finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py new file mode 100644 index 0000000..6b70667 --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRNConfig(PretrainedConfig): + + model_type = 'hgrn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + vocab_size: int = 32000, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: Optional[int] = 1, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + share_conv_kernel: bool = True, + use_lower_bound: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.attn_mode = attn_mode + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.use_lower_bound = use_lower_bound + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/hgrn/modeling_hgrn.py b/finetune/lora/v6/fla/models/hgrn/modeling_hgrn.py new file mode 100644 index 0000000..b41b274 --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn/modeling_hgrn.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.hgrn import HGRNAttention +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class HGRNMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> HGRNMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class HGRNBlock(nn.Module): + def __init__(self, config: HGRNConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = HGRNAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + share_conv_kernel=config.share_conv_kernel, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = HGRNMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRNPreTrainedModel(PreTrainedModel): + + config_class = HGRNConfig + supports_gradient_checkpointing = True + _no_split_modules = ['HGRNBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class HGRNModel(HGRNPreTrainedModel): + + def __init__(self, config: HGRNConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRNForCausalLM(HGRNPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRNModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/hgrn2/__init__.py b/finetune/lora/v6/fla/models/hgrn2/__init__.py new file mode 100644 index 0000000..306b808 --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn2/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model + +AutoConfig.register(HGRN2Config.model_type, HGRN2Config) +AutoModel.register(HGRN2Config, HGRN2Model) +AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) + + +__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] diff --git a/finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py b/finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py new file mode 100644 index 0000000..5f5382c --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRN2Config(PretrainedConfig): + + model_type = 'hgrn2' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + share_conv_kernel: bool = True, + use_lower_bound: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.use_lower_bound = use_lower_bound + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/hgrn2/modeling_hgrn2.py b/finetune/lora/v6/fla/models/hgrn2/modeling_hgrn2.py new file mode 100644 index 0000000..0f530a2 --- /dev/null +++ b/finetune/lora/v6/fla/models/hgrn2/modeling_hgrn2.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.hgrn2 import HGRN2Attention +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class HGRN2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> HGRN2MLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class HGRN2Block(nn.Module): + def __init__(self, config: HGRN2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = HGRN2Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + share_conv_kernel=config.share_conv_kernel, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = HGRN2MLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRN2PreTrainedModel(PreTrainedModel): + + config_class = HGRN2Config + supports_gradient_checkpointing = True + _no_split_modules = ['HGRN2Block'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class HGRN2Model(HGRN2PreTrainedModel): + + def __init__(self, config: HGRN2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRN2Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRN2ForCausalLM(HGRN2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRN2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/linear_attn/__init__.py b/finetune/lora/v6/fla/models/linear_attn/__init__.py new file mode 100644 index 0000000..72d5d02 --- /dev/null +++ b/finetune/lora/v6/fla/models/linear_attn/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.linear_attn.configuration_linear_attn import \ + LinearAttentionConfig +from fla.models.linear_attn.modeling_linear_attn import ( + LinearAttentionForCausalLM, LinearAttentionModel) + +AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) +AutoModel.register(LinearAttentionConfig, LinearAttentionModel) +AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) + +__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] diff --git a/finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py b/finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py new file mode 100644 index 0000000..35d6d20 --- /dev/null +++ b/finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class LinearAttentionConfig(PretrainedConfig): + + model_type = 'linear_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + attn_mode: str = "fused_chunk", + feature_map: str = "elementwise_product", + tie_feature_map_qk: bool = False, + norm_q: bool = False, + norm_k: bool = False, + norm_feature_map: bool = False, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.attn_mode = attn_mode + self.feature_map = feature_map + self.tie_feature_map_qk = tie_feature_map_qk + self.norm_q = norm_q + self.norm_k = norm_k + self.norm_feature_map = norm_feature_map + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/linear_attn/modeling_linear_attn.py b/finetune/lora/v6/fla/models/linear_attn/modeling_linear_attn.py new file mode 100644 index 0000000..cfcc5c0 --- /dev/null +++ b/finetune/lora/v6/fla/models/linear_attn/modeling_linear_attn.py @@ -0,0 +1,424 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.linear_attn import LinearAttention +from fla.models.linear_attn.configuration_linear_attn import \ + LinearAttentionConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class LinearAttentionMLP(nn.Module): + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> LinearAttentionMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class LinearAttentionBlock(nn.Module): + def __init__(self, config: LinearAttentionConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = LinearAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + mode=config.attn_mode, + feature_map=config.feature_map, + tie_feature_map_qk=config.tie_feature_map_qk, + norm_q=config.norm_q, + norm_k=config.norm_k, + do_feature_map_norm=config.norm_feature_map, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = LinearAttentionMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + # currently not supported + attn_weights, present_key_value = None, None + + hidden_states = self.attn_norm(hidden_states) + hidden_states = self.attn(hidden_states) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LinearAttentionPreTrainedModel(PreTrainedModel): + config_class = LinearAttentionConfig + supports_gradient_checkpointing = True + _no_split_modules = ['LinearAttentionBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class LinearAttentionModel(LinearAttentionPreTrainedModel): + + def __init__(self, config: LinearAttentionConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`LinearAttentionModel` does not support output attention weights now, " + "so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + _, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + _, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LinearAttentionModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exc: + # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'" + if 'past_key_values' in str(exc): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exc + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + state: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for inputs_ids if the state is passed along. + if state is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and state is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + model_inputs["state"] = state + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/mamba/__init__.py b/finetune/lora/v6/fla/models/mamba/__init__.py new file mode 100644 index 0000000..a0eff2e --- /dev/null +++ b/finetune/lora/v6/fla/models/mamba/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM, + MambaModel) + +AutoConfig.register(MambaConfig.model_type, MambaConfig, True) +AutoModel.register(MambaConfig, MambaModel, True) +AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) + + +__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] diff --git a/finetune/lora/v6/fla/models/mamba/configuration_mamba.py b/finetune/lora/v6/fla/models/mamba/configuration_mamba.py new file mode 100644 index 0000000..0467c05 --- /dev/null +++ b/finetune/lora/v6/fla/models/mamba/configuration_mamba.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MAMBA configuration""" + +import math + +from transformers.configuration_utils import PretrainedConfig + + +class MambaConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA + [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MambaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. + If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the the discretization projection matrix. + `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_scale (`float`, *optional*, defaults to 1.0): + Scale used used to scale `dt_proj.bias`. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_init_scheme (`float`, *optional*, defaults to `"random"`): + Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + + + Example: + + ```python + >>> from transformers import MambaConfig, MambaModel + + >>> # Initializing a Mamba configuration + >>> configuration = MambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba" + + def __init__( + self, + vocab_size=32000, + hidden_size=2048, + state_size=16, + num_hidden_layers=48, + layer_norm_epsilon=1e-5, + pad_token_id= 0, + bos_token_id= 1, + eos_token_id= 2, + expand=2, + conv_kernel=4, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=False, + time_step_rank="auto", + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=0.1, + time_step_init_scheme="random", + time_step_floor=1e-4, + rescale_prenorm_residual=False, + use_cache=True, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/finetune/lora/v6/fla/models/mamba/modeling_mamba.py b/finetune/lora/v6/fla/models/mamba/modeling_mamba.py new file mode 100644 index 0000000..5b9760f --- /dev/null +++ b/finetune/lora/v6/fla/models/mamba/modeling_mamba.py @@ -0,0 +1,605 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MAMBA model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm + +logger = logging.get_logger(__name__) + +try: + from mamba_ssm.ops.selective_scan_interface import (mamba_inner_fn, + selective_scan_fn) + from mamba_ssm.ops.triton.selective_state_update import \ + selective_state_update +except ImportError: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +class MambaCache: + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.seqlen_offset = 0 + self.dtype = dtype + intermediate_size = config.intermediate_size + ssm_state_size = config.state_size + conv_kernel_size = config.conv_kernel + + self.conv_states = { + i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + + +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = config.time_step_rank + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.intermediate_size, + padding=config.conv_kernel - 1, + ) + + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) + # selective projection used to make dt, B and C input dependant + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of " + "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the naive implementation. " + "To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Optional[MambaCache] = None): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_params.seqlen_offset > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_params.seqlen_offset > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + # fmt: off + def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + # [batch, 2 * intermediate_size, seq_len] + projected_states = self.in_proj(input_states).transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + if cache_params.seqlen_offset > 0: + # [batch, intermediate_size, conv_kernel_size] + conv_state = cache_params.conv_states[self.layer_idx] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + # [batch, intermediate_size, 1] : decoding + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) + else: + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + # [batch, intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # [batch, intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + # [batch, seq_len, intermediate_size] + discrete_time_step = self.dt_proj(time_step) + # [batch, intermediate_size, seq_len] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + # [intermediate_size, ssm_state_size] + A = -torch.exp(self.A_log.float()) + # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) + # [batch, intermediade_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + # [batch, intermediade_size, ssm_state] + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + # [batch, intermediade_size, 1] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) + scan_outputs.append(scan_output[:, :, 0]) + # [batch, seq_len, intermediade_size] + scan_output = torch.stack(scan_outputs, dim=-1) + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. Final linear projection + # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) + return contextualized_states + # fmt: on + + def forward(self, hidden_states, cache_params: Optional[MambaCache] = None): + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params) + return self.slow_forward(hidden_states, cache_params) + + +class MambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = MambaMixer(config, layer_idx=layer_idx) + + def forward(self, hidden_states, cache_params: Optional[MambaCache] = None): + residual = hidden_states + hidden_states = self.norm(hidden_states) + # if self.residual_in_fp32: + # residual = residual.to(torch.float32) + hidden_states = self.mixer(hidden_states, cache_params=cache_params) + hidden_states = residual + hidden_states + return hidden_states + + +class MambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["MambaBlock"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, MambaMixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.copy_(inv_dt) + module.dt_proj.bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_layers) + + +@dataclass +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + ) -> Union[Tuple, MambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params) + else: + hidden_states = mixer_block(hidden_states, cache_params=cache_params) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return MambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class MambaForCausalLM(MambaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = MambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + return model_kwargs + + def prepare_inputs_for_generation( + self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, attention_mask=None, **kwargs + ): + # only last token for inputs_ids if the state is passed along. + if cache_params is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs["cache_params"] = cache_params + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, MambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + ) + hidden_states = mamba_outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba_outputs[1:] + return (loss,) + output if loss is not None else output + + return MambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) diff --git a/finetune/lora/v6/fla/models/retnet/__init__.py b/finetune/lora/v6/fla/models/retnet/__init__.py new file mode 100644 index 0000000..ad7d9e9 --- /dev/null +++ b/finetune/lora/v6/fla/models/retnet/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel + +AutoConfig.register(RetNetConfig.model_type, RetNetConfig) +AutoModel.register(RetNetConfig, RetNetModel) +AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) + + +__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] diff --git a/finetune/lora/v6/fla/models/retnet/configuration_retnet.py b/finetune/lora/v6/fla/models/retnet/configuration_retnet.py new file mode 100644 index 0000000..b01bda8 --- /dev/null +++ b/finetune/lora/v6/fla/models/retnet/configuration_retnet.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RetNetConfig(PretrainedConfig): + + model_type = 'retnet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 2, + hidden_ratio: Optional[int] = 2, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "fused_chunk", + hidden_act: str = "swish", + use_short_conv: bool = False, + conv_size: int = 4, + share_conv_kernel: bool = True, + use_output_gate: bool = True, + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ) -> RetNetConfig: + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.hidden_act = hidden_act + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.share_conv_kernel = share_conv_kernel + self.use_output_gate = use_output_gate + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/retnet/modeling_retnet.py b/finetune/lora/v6/fla/models/retnet/modeling_retnet.py new file mode 100644 index 0000000..49a3eb2 --- /dev/null +++ b/finetune/lora/v6/fla/models/retnet/modeling_retnet.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.multiscale_retention import MultiScaleRetention +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu_linear + +logger = logging.get_logger(__name__) + + +class RetNetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> RetNetMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class RetNetBlock(nn.Module): + def __init__(self, config: RetNetConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = MultiScaleRetention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = RetNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class RetNetPreTrainedModel(PreTrainedModel): + + config_class = RetNetConfig + supports_gradient_checkpointing = True + _no_split_modules = ['RetNetBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RetNetModel(RetNetPreTrainedModel): + + def __init__(self, config: RetNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_len = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_len = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RetNetForCausalLM(RetNetPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RetNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'" + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/rwkv6/__init__.py b/finetune/lora/v6/fla/models/rwkv6/__init__.py new file mode 100644 index 0000000..942c6dc --- /dev/null +++ b/finetune/lora/v6/fla/models/rwkv6/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model + +AutoConfig.register(RWKV6Config.model_type, RWKV6Config) +AutoModel.register(RWKV6Config, RWKV6Model) +AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) + + +__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] diff --git a/finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py b/finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py new file mode 100644 index 0000000..ff187a8 --- /dev/null +++ b/finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RWKV6Config(PretrainedConfig): + + model_type = 'rwkv6' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + vocab_size: int = 32000, + hidden_size: int = 2048, + expand_k: int = 0.5, + expand_v: int = 1, + hidden_ratio: Optional[int] = 3.5, + intermediate_size: Optional[int] = None, + use_glu: Optional[bool] = False, + num_hidden_layers: int = 24, + num_heads: int = 4, + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.use_glu = use_glu + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.attn_mode = attn_mode + self.hidden_act = hidden_act + self.eps = eps + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/rwkv6/modeling_rwkv6.py b/finetune/lora/v6/fla/models/rwkv6/modeling_rwkv6.py new file mode 100644 index 0000000..ef701e0 --- /dev/null +++ b/finetune/lora/v6/fla/models/rwkv6/modeling_rwkv6.py @@ -0,0 +1,443 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.layers.rwkv6 import LerpLinear, RWKV6Attention +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.utils import RecurrentCache +from fla.modules import FusedCrossEntropyLoss, LayerNorm +from fla.modules.activations import ACT2FN, swiglu_linear + +logger = logging.get_logger(__name__) + + +class RWKV6FeedForward(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'sqrelu', + layer_idx: int = None + ) -> RWKV6FeedForward: + super().__init__() + + self.hidden_size = hidden_size + if hidden_ratio is None: + hidden_ratio = 3.5 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio) + intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.key = LerpLinear(hidden_size, intermediate_size) + self.value = nn.Linear(intermediate_size, hidden_size) + self.receptance = LerpLinear(hidden_size, hidden_size) + self.act_fn = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + + def forward(self, x: torch.Tensor, state: Optional[torch.Tensor] = None) -> torch.Tensor: + if state is not None: + raise NotImplementedError("Past state is not yet supported in `RWKV6FeedForward`.") + shifted = self.time_shift(x) + if len(shifted.shape) == 2: + shifted = shifted.unsqueeze(1) + delta = shifted - x + key = self.act_fn(self.key(x, delta)) + value = self.value(key) + receptance = self.receptance(x, delta) + return receptance.sigmoid() * value + + +class RWKV6GLU(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish', + layer_idx: int = None + ) -> RWKV6GLU: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.layer_idx = layer_idx + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class RWKV6Block(nn.Module): + def __init__(self, config: RWKV6Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps) + self.attn = RWKV6Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + proj_low_rank_dim=config.proj_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + eps=config.eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.ffn_norm = LayerNorm(hidden_size=config.hidden_size, eps=config.eps) + self.ffn = (RWKV6GLU if config.use_glu else RWKV6FeedForward)( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + hidden_states = self.ffn(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class RWKV6PreTrainedModel(PreTrainedModel): + + config_class = RWKV6Config + supports_gradient_checkpointing = True + _no_split_modules = ['RWKV6Block'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV6Model(RWKV6PreTrainedModel): + + def __init__(self, config: RWKV6Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = LayerNorm(config.hidden_size, eps=config.eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV6Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache: + if past_key_values is None: + past_key_values = [layer.attn.init_state(batch_size) for layer in self.layers] + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: + return tuple(x for x in [hidden_states, next_cache, all_hidden_states, all_attns] if x is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV6ForCausalLM(RWKV6PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV6Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + if not isinstance(past_key_values, RecurrentCache): + past_key_values = RecurrentCache.from_legacy_cache(past_key_values, input_ids.shape[1] - 1) + input_ids, attention_mask = input_ids[:, -1:], attention_mask[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[List[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/transformer/__init__.py b/finetune/lora/v6/fla/models/transformer/__init__.py new file mode 100644 index 0000000..47df999 --- /dev/null +++ b/finetune/lora/v6/fla/models/transformer/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.models.transformer.modeling_transformer import ( + TransformerForCausalLM, TransformerModel) + +AutoConfig.register(TransformerConfig.model_type, TransformerConfig) +AutoModel.register(TransformerConfig, TransformerModel) +AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) + + +__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] diff --git a/finetune/lora/v6/fla/models/transformer/configuration_transformer.py b/finetune/lora/v6/fla/models/transformer/configuration_transformer.py new file mode 100644 index 0000000..10e7fdc --- /dev/null +++ b/finetune/lora/v6/fla/models/transformer/configuration_transformer.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class TransformerConfig(PretrainedConfig): + + model_type = 'transformer' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + attention_bias: bool = False, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + self.attention_bias = attention_bias + self.fuse_cross_entropy = fuse_cross_entropy + self.fuse_norm = fuse_norm + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/finetune/lora/v6/fla/models/transformer/modeling_transformer.py b/finetune/lora/v6/fla/models/transformer/modeling_transformer.py new file mode 100644 index 0000000..9d51f73 --- /dev/null +++ b/finetune/lora/v6/fla/models/transformer/modeling_transformer.py @@ -0,0 +1,522 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding +from fla.modules.activations import swiglu_linear + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import (index_first_axis, pad_input, + unpad_input) +except ImportError: + warnings.warn("Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`") + flash_attn_func = None +logger = logging.get_logger(__name__) + + +class TransformerAttention(nn.Module): + + def __init__( + self, + config: TransformerConfig, + layer_idx: Optional[int] = None, + **kwargs + ): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.num_heads = config.num_heads + if config.num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = config.num_kv_heads + self.num_kv_groups = config.num_heads // self.num_kv_heads + self.hidden_size = config.hidden_size + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.rotary = RotaryEmbedding(self.head_dim) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, q_len, _ = hidden_states.size() + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', h=self.num_kv_heads) + v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b h t d', h=self.num_kv_heads) + + seqlen_offset = 0 + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + q, k = self.rotary(q, k, seqlen_offset, self.max_position_embeddings) + + k = rearrange(k, 'b t h d -> b h t d') + if past_key_values is not None: + k, v = past_key_values.update(k, v, self.layer_idx) + k, v = rearrange(k, 'b h t d -> b t h d'), rearrange(v, 'b h t d -> b t h d') + if self.num_kv_groups > 1: + k = rearrange(k.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d') + v = rearrange(v.unsqueeze(-2).repeat(1, 1, 1, self.num_kv_groups, 1), 'b t h g d -> b t (h g) d') + + if flash_attn_func is None: + raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") + + # Contains at least one padding token in the sequence + if attention_mask is not None: + q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=True + ) + o = pad_input(o, indices_q, batch_size, q_len) + else: + o = flash_attn_func(q, k, v, causal=True) + o = o.reshape(batch_size, q_len, self.hidden_size) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values + + def _upad_input(self, q, k, v, attention_mask, q_len): + seqlens = attention_mask.sum(-1, dtype=torch.int32) + indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_k = seqlens.max().item() + cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + batch_size, seq_len, num_key_value_heads, head_dim = k.shape + + k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + if q_len == seq_len: + q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_q = max_seqlen_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_q = 1 + # There is a memcpy here, that is very bad. + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -q_len:] + q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) + + return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +class TransformerMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish' + ) -> TransformerMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + y = self.gate_proj(x) + gate, y = y.chunk(2, -1) + return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + + +class TransformerBlock(nn.Module): + def __init__(self, config: TransformerConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.attn = TransformerAttention( + config=config, + layer_idx=layer_idx + ) + self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + self.mlp = TransformerMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class TransformerPreTrainedModel(PreTrainedModel): + + config_class = TransformerConfig + supports_gradient_checkpointing = True + _no_split_modules = ['TransformerBlock'] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["o_proj.weight", "down_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class TransformerModel(TransformerPreTrainedModel): + + def __init__(self, config: TransformerConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_decoder_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class TransformerForCausalLM(TransformerPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = TransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if self.config.fuse_cross_entropy: + loss_fct = FusedCrossEntropyLoss(inplace_backward=True) + else: + loss_fct = nn.CrossEntropyLoss() + # Enable model parallelism + labels = labels.to(logits.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1) + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/finetune/lora/v6/fla/models/utils.py b/finetune/lora/v6/fla/models/utils.py new file mode 100644 index 0000000..ed51f7d --- /dev/null +++ b/finetune/lora/v6/fla/models/utils.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import torch +from transformers.cache_utils import Cache + + +class RecurrentCache(Cache): + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. + """ + + def __init__( + self, + seen_tokens: int = 0 + ) -> RecurrentCache: + + self.states: List[torch.Tensor] = [] + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> torch.Tensor: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + state: Tuple[torch.Tensor], + layer_idx: int, + offset: Optional[int] = 1, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor]: + """ + Updates the cache with the new `state` for the layer `layer_idx`. + + Parameters: + state (`Tuple[torch.Tensor]`): + The new state to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + offset (`int`): + The offset of current fed tokens. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + The updated state. + """ + + if isinstance(state, torch.Tensor): + state = (state,) + if len(self.states) <= layer_idx: + self.states.append(state) + else: + for i, s in enumerate(state): + self.states[layer_idx][i].copy_(s) + # update the number of seen tokens once we achieve the last layer + if layer_idx == len(self) - 1: + self._seen_tokens += offset + + return state + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. RecurrentCache does not have a maximum length.""" + return None + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.states)): + device = self.states[layer_idx].device + self.states[layer_idx] = self.states[layer_idx].index_select(0, beam_idx.to(device)) + + def to_legacy_cache(self) -> Tuple[torch.Tensor]: + return tuple(self.states) + + @classmethod + def from_legacy_cache( + cls, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + seen_tokens: int = 0 + ) -> RecurrentCache: + """Converts a cache in the legacy cache format into an equivalent `RecurrentCache`.""" + + cache = cls(seen_tokens) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + cache.update(past_key_values[layer_idx], layer_idx) + return cache diff --git a/finetune/lora/v6/fla/modules/__init__.py b/finetune/lora/v6/fla/modules/__init__.py new file mode 100644 index 0000000..4874b9d --- /dev/null +++ b/finetune/lora/v6/fla/modules/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution, + ShortConvolution) +from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss +from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate, + FusedLayerNormSwishGateLinear, + FusedRMSNormSwishGate, + FusedRMSNormSwishGateLinear) +from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm, + RMSNormLinear) +from fla.modules.rotary import RotaryEmbedding + +__all__ = [ + 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', + 'FusedCrossEntropyLoss', + 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', + 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', + 'RotaryEmbedding' +] diff --git a/finetune/lora/v6/fla/modules/activations.py b/finetune/lora/v6/fla/modules/activations.py new file mode 100644 index 0000000..46d86c5 --- /dev/null +++ b/finetune/lora/v6/fla/modules/activations.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Tri Dao, Yu Zhang, Songlin Yang. + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.utils import contiguous + +sigmoid_fwd_codestring = """ +template T sigmoid_fwd(T x) { + return 1.0f / (1.0f + ::exp(-float(x))); +} +""" +sigmoid_bwd_codestring = """ +template T sigmoid_bwd(T x, T g) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(g) * x_sigmoid * (1.0f - x_sigmoid); +} +""" + +sigmoid_fwd = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring) +sigmoid_bwd = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring) + + +class SigmoidFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return sigmoid_fwd(x) + + @staticmethod + def backward(ctx, dout): + x, = ctx.saved_tensors + return sigmoid_bwd(x, dout) + + +sigmoid = SigmoidFunction.apply + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + triton.Config({'BT': 128}, num_warps=2), + triton.Config({'BT': 128}, num_warps=4), + triton.Config({'BT': 128}, num_warps=8), + triton.Config({'BT': 256}, num_warps=2), + triton.Config({'BT': 256}, num_warps=4), + triton.Config({'BT': 256}, num_warps=8) + ], + key=['D'] +) +@triton.jit +def logsigmoid_fwd_kernel( + x, + y, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr +): + i = tl.program_id(0) + o_i = i * BT + tl.arange(0, BT) + + p_x = x + o_i + p_y = y + o_i + mask = o_i < T + + # [D,] + b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32) + b_m = tl.minimum(0., b_x) + b_z = 1. + tl.exp(-tl.abs(b_x)) + b_y = b_m - tl.log(b_z) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + triton.Config({'BT': 128}, num_warps=2), + triton.Config({'BT': 128}, num_warps=4), + triton.Config({'BT': 128}, num_warps=8), + triton.Config({'BT': 256}, num_warps=2), + triton.Config({'BT': 256}, num_warps=4), + triton.Config({'BT': 256}, num_warps=8) + ], + key=['D'] +) +@triton.jit +def logsigmoid_bwd_kernel( + x, + dx, + dy, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr +): + i = tl.program_id(0) + o_i = i * BT + tl.arange(0, BT) + + p_x = x + o_i + p_dx = dx + o_i + p_dy = dy + o_i + mask = o_i < T + + # [D,] + b_x = tl.load(p_x, mask=mask, other=0.).to(tl.float32) + b_dy = tl.load(p_dy, mask=mask, other=0.).to(tl.float32) + b_dx = b_dy * (1. - tl.sigmoid(b_x)) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + +class LogSigmoidFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, x): + T, D = x.numel(), x.shape[-1] + y = torch.empty_like(x) + logsigmoid_fwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, y, T=T, D=D) + ctx.save_for_backward(x,) + return y + + @staticmethod + @contiguous + def backward(ctx, dy): + x, = ctx.saved_tensors + T, D = x.numel(), x.shape[-1] + dx = torch.empty_like(x) + logsigmoid_bwd_kernel[lambda meta: (triton.cdiv(meta['T'], meta['D']),)](x, dx, dy, T=T, D=D) + return dx + + +logsigmoid = LogSigmoidFunction.apply + +swish_fwd_codestring = """ +template T swish_fwd(T x) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(x) * x_sigmoid; +} +""" +swish_bwd_codestring = """ +template T swish_bwd(T x, T g) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x)); +} +""" + +swish_fwd = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring) +swish_bwd = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring) + + +class SwishFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_fwd(x) + + @staticmethod + def backward(ctx, dout): + x, = ctx.saved_tensors + return swish_bwd(x, dout) + + +swish = SwishFunction.apply + +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 + + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def bias_gelu(y, bias): + x = bias + y + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_bwd(g, y, bias): + """Assume that y has shape (B, D) and bias has shape (D)""" + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + grad_y = ff * g + return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) + + +class GeLUFunction(torch.autograd.Function): + + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_bwd(grad_output, input, bias) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply + + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def gelu_fwd(x): + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) + + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def gelu_bwd(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + return (ff * g).to(dtype=x.dtype) + + +class FastGeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input): + ctx.save_for_backward(input) + return gelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + tmp = gelu_bwd(grad_output, input) + return tmp + + +fast_gelu_impl = FastGeLUFunction.apply + + +@torch.jit.script +def relu_bwd(g, x): + return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_fwd(x): + r = F.relu(x) + return (r * r).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_bwd(g, x): + return (2.0 * g * F.relu(x)).to(dtype=x.dtype) + + +class SquaredReLUFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return sqrelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return sqrelu_bwd(grad_output, input) + + +sqrelu = SquaredReLUFunction.apply + + +swiglu_fwd_codestring = """ +template T swiglu_fwd(T x, T y) { + return float(x) * float(y) / (1.0f + ::exp(-float(x))); +} +""" +swiglu_bwd_codestring = """ +template T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = float(x) * x_sigmoid * float(g); +} +""" + +swiglu_bwd_with_output_codestring = """ +template T swiglu_bwd_with_output(T x, T y, T g, T& dx, T& dy, T& z) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + float x_swish = float(x) * x_sigmoid; + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = x_swish * float(g); + z = x_swish * float(y); +} +""" + +swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) +swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) +swiglu_bwd_with_output = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_with_output_codestring, num_outputs=3) + + +class SwiGLUFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function. + + .. math:: + \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y + """ + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return swiglu_fwd(x, y) + + @staticmethod + def backward(ctx, dout): + x, y = ctx.saved_tensors + return swiglu_bwd(x, y, dout) + + +class SwiGLULinearFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation. + + .. math:: + \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b + + This simple wrap discards the intermediate results of SwiGLU(x, y) to save memory. + """ + + @staticmethod + def forward(ctx, x, y, weight, bias): + z = swiglu_fwd(x, y) + out = F.linear(z.to(weight.dtype), weight, bias) + # We don't store z, will be recomputed in the backward pass to save memory + ctx.save_for_backward(x, y, weight) + ctx.linear_bias_is_none = bias is None + return out + + @staticmethod + def backward(ctx, dout, *args): + x, y, weight = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dz = F.linear(dout, weight.t()).view_as(x) + dx, dy, z = swiglu_bwd_with_output(x, y, dz) + dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1])) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + return dx, dy, dlinear_weight, dlinear_bias + + +swiglu = SwiGLUFunction.apply + +swiglu_linear = SwiGLULinearFunction.apply + +ACT2FN = { + 'relu': F.relu, + 'sigmoid': sigmoid, + 'logsigmoid': logsigmoid, + 'silu': swish, + 'swish': swish, + 'sqrelu': sqrelu, + 'gelu': fast_gelu_impl, + 'bias_gelu': bias_gelu_impl, +} diff --git a/finetune/lora/v6/fla/modules/convolution.py b/finetune/lora/v6/fla/modules/convolution.py new file mode 100644 index 0000000..3e2e07d --- /dev/null +++ b/finetune/lora/v6/fla/modules/convolution.py @@ -0,0 +1,336 @@ +# -*- coding: utf-8 -*- + +# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py + +import math +import warnings +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules.activations import ACT2FN +from fla.utils import checkpoint + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + + +def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +@checkpoint +def proj_then_conv1d( + x: torch.Tensor, + proj_weight: torch.Tensor, + conv1d_weight: torch.Tensor, + conv1d_bias: Optional[torch.Tensor] = None, + cache: Optional[torch.Tensor] = None +) -> torch.Tensor: + # We do matmul and transpose BLH -> HBL at the same time + x = rearrange(proj_weight @ rearrange(x, "b l d -> d (b l)"), "d (b l) -> b d l", l=x.shape[-2]) + + if causal_conv1d_fn is None: + raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.") + if cache is None: + x = causal_conv1d_fn( + x=x, + weight=rearrange(conv1d_weight, "d 1 w -> d w"), + bias=conv1d_bias, + activation="silu", + ).transpose(1, 2) + else: + assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now" + x = x.squeeze(-1) + x = causal_conv1d_update( + x=x, + weight=rearrange(conv1d_weight, "d 1 w -> d w"), + bias=conv1d_bias, + cache=cache, + activation="silu", + ) + return x + + +class ShortConvolution(nn.Conv1d): + """ + Simple wrapper around `nn.Conv1d` that accepts dimension last. + """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + bias: bool = False, + activation: Optional[str] = 'silu', + use_causal_conv: Optional[bool] = True + ): + super().__init__(in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + bias=bias, + padding=kernel_size - 1) + + self.hidden_size = hidden_size + self.activation = None + if activation is not None: + assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet." + self.activation = activation + + if use_causal_conv: + if causal_conv1d_fn is None: + warnings.warn("Please install `causal-conv1d` to use causal convolutions, setting `use_causal_conv` to False.") + use_causal_conv = False + self.use_causal_conv = use_causal_conv + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + if self.activation is not None: + s += ', activation={activation}' + if not self.use_causal_conv: + s += ', use_causal_conv={use_causal_conv}' + return s.format(**self.__dict__) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + cache: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Args: + x (`torch.Tensor`): + Tensor of shape `[batch_size, seq_len, hidden_size]` + mask (`Optional[torch.Tensor]`): + Attention mask dealing with padded positions. + cache (`Optional[torch.Tensor]`): + Previous cache tensor of shape `[batch_size, hidden_size, kernel_size]`, + Returns: + Tensor of shape `[batch_size, seq_len, hidden_size]`. The `cache` (if provided) is updated inplace. + """ + + if mask is not None: + x = x.mul_(mask.unsqueeze(-1)) + if cache is not None and x.shape[1] == 1: + return self.step(x, cache) + x = rearrange(x, "b l d -> b d l") + # Update state (B D W) + if cache is not None: + cache.copy_(F.pad(x, (self.kernel_size[0] - x.shape[-1], 0))) + if self.use_causal_conv: + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + else: + x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]] + if self.activation is not None: + x = ACT2FN[self.activation](x) + return rearrange(x, "b d l -> b l d") + + def step( + self, + x: torch.Tensor, + cache: torch.Tensor + ): + assert x.shape[1] == 1, "Only support decoding with 1 token at a time for now" + + x = x.squeeze(1) + if self.use_causal_conv: + x = causal_conv1d_update( + x=x, + conv_state=cache, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + else: + dtype = x.dtype + cache.copy_(torch.roll(cache, shifts=-1, dims=-1)) + cache[:, :, -1] = x + x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1) + if self.bias is not None: + x = x + self.bias + if self.activation is not None: + x = ACT2FN[self.activation](x).to(dtype=dtype) + return x.unsqueeze(1) + + @property + def state_size(self) -> int: + return self.hidden_size * self.kernel_size + + +class LongConvolution(nn.Module): + """ + LongConvolution applies a convolution operation on the input tensor using a fixed + filter of length l_max. + The filter is learned during training and is applied using FFT convolution. + Args: + hidden_size (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + Returns: + y: (b, l, d) tensor + """ + + def __init__( + self, + hidden_size: int, + l_max: int, + **kwargs, + ): + """ + Initializes the LongConvolution module. + Args: + hidden_size (int): The number of expected features in the input and output. + l_max (int): The maximum sequence length. + """ + super().__init__() + self.hidden_size = hidden_size + self.filter = nn.Parameter(torch.randn(self.hidden_size, l_max), requires_grad=True) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Applies the LongConvolution operation on the input tensor. + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) + + +class PositionalEmbedding(nn.Module): + def __init__(self, emb_dim: int, seq_len: int, **kwargs): + """Complex exponential positional embeddings for implicit long convolution filters.""" + super().__init__() + + self.seq_len = seq_len + # The time embedding fed to the filteres is normalized so that t_f = 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + + if emb_dim > 1: + bands = (emb_dim - 1) // 2 + # To compute the right embeddings we use the "proper" linspace + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + self.z = nn.Parameter(z, requires_grad=False) + + def forward(self, L): + return self.z[:, :L] + + +class ImplicitLongConvolution(nn.Module): + """ + Long convolution with implicit filter parameterized by an MLP. + + Args: + hidden_size (int): + The number of expected features in the input and output. + l_max (int): + The maximum sequence length. + d_emb (Optional[int]): + The dimension of the positional embeddings. Must be odd and greater or equal to 3 (time, sine and cosine). + Defaults to 3. + d_hidden (Optional[int]): + The number of features in the hidden layer of the MLP. Defaults to 16. + + Attributes: + pos_emb (`PositionalEmbedding`): The positional embedding layer. + mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter. + + """ + + def __init__( + self, + hidden_size: int, + l_max: int, + d_emb: int = 3, + d_hidden: int = 16, + **kwargs, + ): + """ + Long convolution with implicit filter parameterized by an MLP. + + + """ + super().__init__() + self.hidden_size = hidden_size + self.d_emb = d_emb + + assert ( + d_emb % 2 != 0 and d_emb >= 3 + ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" + self.pos_emb = PositionalEmbedding(d_emb, l_max) + + # final linear layer + self.mlp = nn.Sequential( + nn.Linear(d_emb, d_hidden), + torch.nn.ReLU(), + nn.Linear(d_hidden, hidden_size), + ) + + def filter(self, seq_len: int, *args, **kwargs): + k = self.mlp(self.pos_emb(seq_len)) + + return k.transpose(1, 2) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Args: + x: (b, l, d) tensor + Returns: + y: (b, l, d) tensor + """ + x = x.transpose(1, 2) + k = self.filter(x.shape[-1]) + y = fft_conv(x, k, dropout_mask=None, gelu=False) + + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) diff --git a/finetune/lora/v6/fla/modules/feature_map.py b/finetune/lora/v6/fla/modules/feature_map.py new file mode 100644 index 0000000..43c3bb1 --- /dev/null +++ b/finetune/lora/v6/fla/modules/feature_map.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from fla.modules.layernorm import layer_norm_fn +from fla.utils import checkpoint + + +@checkpoint +def flatten_diag_outer_product(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N) + return z[..., indicies[0], indicies[1]] + + +@checkpoint +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + + +def is_power_of_2(n): + return (n & (n - 1) == 0) and n != 0 + + +class HedgehogFeatureMap(nn.Module): + + r""" + Hedgehog feature map as introduced in + `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry `_ + """ + + def __init__( + self, + head_dim: int + ) -> HedgehogFeatureMap: + super().__init__() + # Trainable map + self.layer = nn.Linear(head_dim, head_dim) + self.init_weights_() + + def init_weights_(self): + """Initialize trainable map as identity""" + with torch.no_grad(): + identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float) + self.layer.weight.copy_(identity.to(self.layer.weight)) + nn.init.zeros_(self.layer.bias) + + def forward(self, x: torch.Tensor): + x = self.layer(x) # shape b, h, l, d + return torch.cat([2*x, -2*x], dim=-1).softmax(-1) + + +class T2RFeatureMap(nn.Module): + + r""" + Simple linear mapping feature map as in + `Finetuning Pretrained Transformers into RNNs `_ + """ + + def __init__( + self, + head_dim: int, + dot_dim: int = None + ) -> T2RFeatureMap: + super().__init__() + # Trainable map + if dot_dim is None: + dot_dim = head_dim + self.layer = nn.Linear(head_dim, dot_dim) + + def forward(self, x: torch.Tensor): + return self.layer(x).relu() + + +class DPFPFeatureMap(nn.Module): + + r""" + Deterministic Parameter-Free Projection (DPFP) feature map in + `Linear Transformers Are Secretly Fast Weight Programmers `_ + """ + + def __init__( + self, + head_dim: int, + nu: int = 4 + ) -> DPFPFeatureMap: + super().__init__() + self.nu = nu + + def forward(self, x: torch.Tensor): + x = torch.cat([x.relu(), -x.relu()], dim=-1) + x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1) + x_repeat = torch.cat([x] * self.nu, dim=-1) + return x_repeat * x_rolled + + +class HadamardFeatureMap(nn.Module): + def __init__( + self, + head_dim: int + ) -> HadamardFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, head_dim) + self.layer2 = nn.Linear(head_dim, head_dim) + + def forward(self, x: torch.Tensor): + return self.layer1(x) * self.layer2(x) + + +class LearnableOuterProductFeatureMap(nn.Module): + def __init__( + self, + head_dim: int, + feature_dim: int + ) -> LearnableOuterProductFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, feature_dim, bias=False) + self.layer2 = nn.Linear(head_dim, feature_dim, bias=False) + self.normalizer = feature_dim ** -0.5 + + def forward(self, x: torch.Tensor): + return flatten_diag_outer_product(self.layer1(x), self.layer2(x)) + + +class LearnablePolySketchNonNegativeFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + sketch_size: Optional[int] = None, + degree: Optional[int] = 2 + ) -> LearnablePolySketchNonNegativeFeatureMap: + super().__init__() + + assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2" + + self.head_dim = head_dim + self.sketch_size = sketch_size if sketch_size is not None else head_dim + self.degree = degree + + self.gamma = nn.Parameter(torch.ones(head_dim)) + self.beta = nn.Parameter(torch.zeros(head_dim)) + # NOTE: the sketch layers defined here are quite different from the original paper + # currently we simply use linear layers without any non-linear activations + self.sketches1 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + self.sketches2 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + + def forward(self, x: torch.Tensor): + # Section 2.1 + x = layer_norm_fn(x, self.gamma, self.beta) + # first map the input to sketch size with learnable parameters + x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5 + for i in range(1, int(math.log2(self.degree)) - 1): + x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5 + # do sketch mapping for log2(p) - 1 times in total + # do p=2 mapping to ensure non-negativity + return flatten_diag_outer_product(x, x) + + +class TaylorFeatureMap(nn.Module): + def __init__( + self, + head_dim: int + ) -> TaylorFeatureMap: + super().__init__() + self.head_dim = head_dim + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.head_dim) + self.rrd = math.sqrt(self.rd) + + def forward(self, x: torch.Tensor): + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1) + + +class RebasedFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True + ) -> RebasedFeatureMap: + super().__init__() + + self.head_dim = head_dim + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + + self.gamma = None + self.beta = None + if use_gamma: + self.gamma = nn.Parameter(torch.ones(head_dim)) + if use_beta: + self.beta = nn.Parameter(torch.zeros(head_dim)) + + def forward(self, x: torch.Tensor, flatten: Optional[bool] = True): + if self.use_beta and self.use_gamma and self.normalize: + x = layer_norm_fn(x, self.gamma, self.beta) + elif self.normalize: + x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta) + elif self.use_gamma and self.use_beta: + x = torch.addcmul(self.beta, x, self.gamma) + elif self.use_gamma: + x = x.mul(self.gamma) + else: + raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, " + f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)") + if not flatten: + return x + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + # rebased use learnable parameters to approximate any quadratic function + return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1) diff --git a/finetune/lora/v6/fla/modules/fused_cross_entropy.py b/finetune/lora/v6/fla/modules/fused_cross_entropy.py new file mode 100644 index 0000000..3364680 --- /dev/null +++ b/finetune/lora/v6/fla/modules/fused_cross_entropy.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Tuple + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + n_rows, + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + max_logits = tl.max(logits, 0) + if HAS_SMOOTHING: + sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) + lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) + if label_idx == ignored_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( + n_cols, (col_block_idx + 1) * BLOCK_SIZE + ): + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - smoothing * sum_logits / total_classes + - (1 - smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss + if HAS_SMOOTHING: + loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) + + +@triton.heuristics( + { + "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, + } +) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignored_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = tl.exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_negative = smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) + + +class CrossEntropyLossFunction(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + logits, + labels, + smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignored_index=-100, + inplace_backward=False, + process_group=None, + ): + n_rows, n_cols = logits.shape + assert labels.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + + if logits.stride(-1) != 1: + logits = logits.contiguous() + # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py + MAX_BLOCK_SIZE = 64 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + # We may split the lse computation across multiple blocks, then do a reduction + # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) + # where having just one thread block processing more than 64k elements is slow. + split = world_size > 1 or n_cols > MAX_BLOCK_SIZE + n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) + losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_fwd_kernel[(n_rows, n_splits)]( + losses, # data ptrs + lse, + z_losses, + logits, + labels, + smoothing, + logit_scale, + lse_square_scale, + ignored_index, + total_classes, + class_start_idx, + n_cols, # shapes + n_rows, + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + SPLIT=split, + ) + + if split: + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For labels not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if n_splits > 1: + lse = torch.logsumexp(lse, dim=0) + losses = losses.sum(dim=0) + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. + losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(labels == ignored_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(labels == ignored_index, 0.0) + + ctx.save_for_backward(logits, lse, labels) + ctx.mark_non_differentiable(z_losses) + ctx.smoothing = smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignored_index = ignored_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + + return losses, z_losses + + @staticmethod + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. + + logits, lse, labels = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(logits.device.index): + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + labels, + ctx.smoothing, + ctx.logit_scale, + ctx.lse_square_scale, + ctx.ignored_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + labels: torch.Tensor, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignored_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: (batch, vocab_size) + labels: (batch,) + label_smoothing: float + logit_scale: float. Multiply logits by this scale before calculating the loss. + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + Returns: + losses: (batch,), float + z_losses: (batch,), float + """ + return CrossEntropyLossFunction.apply( + logits, + labels, + label_smoothing, + logit_scale, + lse_square_scale, + ignored_index, + inplace_backward, + process_group, + ) + + +class FusedCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + inplace_backward=False, + process_group=None, + return_z_loss=False, + ): + """ + Arguments: + ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + return_z_loss: bool. If True, we return the component of the loss contributed by + the lse_square_scale value. This value is only for logging and does not support + backprop. + """ + super().__init__() + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.inplace_backward = inplace_backward + self.process_group = process_group + self.return_z_loss = return_z_loss + + def forward(self, input, target): + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss, z_loss = cross_entropy_loss( + input, + target, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + lse_square_scale=self.lse_square_scale, + ignored_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, + ) + if self.reduction == "mean": + loss = loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + loss = loss.sum() + else: + loss = loss + + if not self.return_z_loss: + return loss + + if self.reduction == "mean": + z_loss = z_loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + + return loss, z_loss diff --git a/finetune/lora/v6/fla/modules/fused_norm_gate.py b/finetune/lora/v6/fla/modules/fused_norm_gate.py new file mode 100644 index 0000000..739b5ae --- /dev/null +++ b/finetune/lora/v6/fla/modules/fused_norm_gate.py @@ -0,0 +1,889 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. +# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.utils import contiguous + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + \ + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + O, # pointer to the gate + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + O += row * stride_x_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < + N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w if HAS_WEIGHT else x_hat + if HAS_BIAS: + y = y + b + + # Swish output gate + o = tl.load(O + cols, mask=cols < N, other=0.0).to(tl.float32) + y = y * o * tl.sigmoid(o) + + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, o, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, + device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + o, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + O, # pointer to the gate + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DO, # pointer to the gate gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + O += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + DO += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + o = tl.load(O + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y, mask=mask) + + sigmoid_o = tl.sigmoid(o) + do = dy * y * (sigmoid_o + o * sigmoid_o * (1 - sigmoid_o)) + dy = dy * o * sigmoid_o + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + tl.store(DO + cols, do, mask=mask) + + X += stride_x_row + O += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + DO += stride_dx_row + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + o, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + do = ( + torch.empty_like(o) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = ( + torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + if weight is not None + else None + ) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + o, + weight, + bias, + y, + dy, + dx, + do, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) if weight is not None else None + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, do, dw, db, dresidual_in) if not recompute_output else (dx, do, dw, db, dresidual_in, y) + + +class LayerNormSwishGateFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + o, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + o_shape_og = o.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + o = o.reshape(-1, o.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, o, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, o, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.o_shape_og = o_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dy, *args): + x, o, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, do, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + o, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + do.reshape(ctx.o_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +class LayerNormSwishGateLinearFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + o_shape_og = o.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + o = o.reshape(-1, o.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + o, + norm_weight, + norm_bias, + eps, + residual, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, o, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.o_shape_og = o_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dout, *args): + x, o, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, do, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + o, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + do.reshape(ctx.o_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_swish_gate_fn( + x, + o, + weight, + bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateFn.apply( + x, + o, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_swish_gate_fn( + x, + o, + weight, + bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateFn.apply( + x, + o, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +def layer_norm_swish_gate_linear_fn( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateLinearFn.apply( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_swish_gate_linear_fn( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormSwishGateLinearFn.apply( + x, + o, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +class FusedLayerNormSwishGate(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedLayerNormSwishGate: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_swish_gate_fn( + x, + o, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormSwishGate(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedRMSNormSwishGate: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_swish_gate_fn( + x, + o, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedLayerNormSwishGateLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedLayerNormSwishGateLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_swish_gate_linear_fn( + x, + o, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormSwishGateLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> FusedRMSNormSwishGateLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, o, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_swish_gate_linear_fn( + x, + o, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) diff --git a/finetune/lora/v6/fla/modules/l2norm.py b/finetune/lora/v6/fla/modules/l2norm.py new file mode 100644 index 0000000..9af045f --- /dev/null +++ b/finetune/lora/v6/fla/modules/l2norm.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +import math +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _l2_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + stride_x_row, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_x_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) + rstd = 1 / tl.sqrt(var + eps) + # tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + y = x * rstd + # Write output + tl.store(Y + cols, y, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +# @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _l2_norm_bwd_kernel( + X, # pointer to the input + # Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + stride_x_row, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + DX += row * stride_x_row + DY += row * stride_x_row + + # Y += row * stride_y_row + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x, 0.0) + var = tl.sum(x * x) + rstd = 1 / tl.sqrt(var + eps) + # tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + # y = x * rstd + dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32) + dy = tl.where(cols < N, dy, 0.0) + # dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x + dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x + tl.store(DX + cols, dx, mask=mask) + +def _l2_norm_fwd( + x, eps=1e-6 +): + x_shape_og = x.shape + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + M, N = x.shape + assert x.stride(-1) == 1 + # allocate output + y = torch.empty_like(x) + assert y.stride(-1) == 1 + N = x.shape[-1] + M = x.shape[0] + # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _l2_norm_fwd_1pass_kernel[(M,)]( + x, + y, + x.stride(0), + N, + eps, + # is_rms_norm, + BLOCK_N, + # residual is not None, + # residual_out is not None, + # bias is not None, + ) + return y.reshape(x_shape_og) + +def _l2_norm_bwd( + x, dy, eps=1e-5, +): + x_shape_og = x.shape + x = x.reshape(-1, dy.shape[-1]) + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + # allocate output + dx = torch.empty_like(x) + N = x.shape[-1] + M = x.shape[0] + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _l2_norm_bwd_kernel[(M,)]( + x, + dy, + dx, + x.stride(0), + N, + eps, + BLOCK_N, + ) + return dx.reshape(x_shape_og) + + +class L2NormFN(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + eps=1e-6, + ): + # reshape input data into 2D tensor + y = _l2_norm_fwd(x, eps) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.x_dtype = x.dtype + ctx.save_for_backward(x) + return y + + @staticmethod + def backward(ctx, dy, *args): + x, = ctx.saved_tensors + dx = _l2_norm_bwd( + x, + dy, + ctx.eps, + ) + return ( + dx, + None + ) + +l2_norm_fn = L2NormFN.apply + +if __name__ == '__main__': + x = torch.rand(10, 10, 100).cuda().requires_grad_(True) + y = torch.nn.functional.normalize(x, dim=-1, p=2) + dy = torch.rand_like(y) + y.backward(dy, retain_graph=True) + x_grad, x.grad = x.grad, None + y2 = l2_norm_fn(x, 1e-6) + print((y-y2).abs().max()) + y2.backward(dy, retain_graph=True) + x_grad2, x.grad = x.grad, None + print((x_grad2-x_grad).abs().max()) + breakpoint() + + + + diff --git a/finetune/lora/v6/fla/modules/layernorm.py b/finetune/lora/v6/fla/modules/layernorm.py new file mode 100644 index 0000000..9bd7477 --- /dev/null +++ b/finetune/lora/v6/fla/modules/layernorm.py @@ -0,0 +1,802 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. +# https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.utils import contiguous + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + \ + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < + N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + + y = x_hat * w if HAS_WEIGHT else x_hat + if HAS_BIAS: + y = y + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, + device="cuda") if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + tl.store(Y + cols, y, mask=mask) + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like( + x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, + device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError( + "This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = ( + torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + if weight is not None + else None + ) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) if weight is not None else None + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + prenorm=False, + residual_in_fp32=False, + eps=1e-6 +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) + + +class LayerNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5 + ) -> LayerNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class RMSNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5 + ) -> RMSNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + + @staticmethod + @contiguous + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to( + dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, + norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @contiguous + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) + + +class LayerNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> LayerNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear_fn( + x, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=False + ) + + +class RMSNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps=1e-5 + ) -> RMSNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(hidden_size)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear_fn( + x, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True + ) diff --git a/finetune/lora/v6/fla/modules/rotary.py b/finetune/lora/v6/fla/modules/rotary.py new file mode 100644 index 0000000..f77e5ee --- /dev/null +++ b/finetune/lora/v6/fla/modules/rotary.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Optional, Tuple, Union + +import torch +from einops import rearrange, repeat + +from fla.ops.rotary import apply_rotary + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmb.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, + dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, + dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, + device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to( + device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. + """ + seqlen = q.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=q.device, dtype=q.dtype) + if self.scale is None: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + else: + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + k = apply_rotary_emb_func( + k, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + + return q, k diff --git a/finetune/lora/v6/fla/ops/__init__.py b/finetune/lora/v6/fla/ops/__init__.py new file mode 100644 index 0000000..4f8681d --- /dev/null +++ b/finetune/lora/v6/fla/ops/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + +from .based import fused_chunk_based, parallel_based +from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla +from .retention import (chunk_retention, fused_chunk_retention, + fused_recurrent_retention, parallel_retention) + +__all__ = [ + 'fused_chunk_based', + 'parallel_based', + 'chunk_gla', + 'fused_chunk_gla', + 'fused_recurrent_gla', + 'chunk_retention', + 'fused_chunk_retention', + 'fused_recurrent_retention', + 'parallel_retention' +] diff --git a/finetune/lora/v6/fla/ops/abc/__init__.py b/finetune/lora/v6/fla/ops/abc/__init__.py new file mode 100644 index 0000000..1fa366a --- /dev/null +++ b/finetune/lora/v6/fla/ops/abc/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_abc +from .chunk_gate import chunk_gated_abc +from .recurrent_fuse import fused_recurrent_gated_abc + +__all__ = [ + 'chunk_abc', + 'chunk_gated_abc', + 'fused_recurrent_gated_abc' +] diff --git a/finetune/lora/v6/fla/ops/abc/chunk.py b/finetune/lora/v6/fla/ops/abc/chunk.py new file mode 100644 index 0000000..599317e --- /dev/null +++ b/finetune/lora/v6/fla/ops/abc/chunk.py @@ -0,0 +1,1194 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import (logcumsumexp_fwd_kernel, softmax_bwd_kernel, + softmax_fwd_kernel) +from fla.utils import contiguous + + +@triton.jit +def chunk_abc_fwd_kernel_h( + k, + v, + z, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if NORMK: + p_z0 = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_k * BK,), (BK,), (0,)) + else: + p_z0 = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_z0).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if NORMK: + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[:, None] + b_k = tl.exp(b_k - b_zc[:, None]).to(b_k.dtype) + else: + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[None, :] + b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_intra_K( + v, + z, + o, + A, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, tl.exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_o *= tl.exp(b_zn[None, :] - b_z) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * tl.exp(b_v[None, :] - b_z), 0) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_K( + q, + k, + z, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BV] + p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_o = b_o * tl.exp(b_zp[None, :] - b_z) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_fwd_kernel_intra_V( + q, + k, + z, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = tl.exp(b_k - b_zn[:, None]).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * tl.exp(b_k[None, :] - b_z) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) + tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + + +@triton.jit +def chunk_abc_fwd_kernel_V( + q, + v, + z, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BK] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_q = (b_q * tl.exp(b_zp[None, :] - b_z)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_q, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_dh( + q, + z, + do, + dh, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + i_p = tl.maximum(i_t * BT - 1, 0) + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if NORMK: + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zc - b_zp), b_zc + # [BK, BT] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_zc[:, None] - b_z)).to(b_q.dtype) + # [BK, BV] + b_dh = b_dh * b_r[:, None] + else: + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = tl.exp(b_zc - b_zp), b_zc + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = (b_do * tl.exp(b_zc[None, :] - b_z)).to(b_do.dtype) + # [BK, BV] + b_dh = b_dh * b_r[None, :] + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit +def chunk_abc_bwd_kernel_V( + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = tl.exp(b_k - b_zc[None, :]).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = tl.exp(b_zp[None, :] - b_z) + # [BT, BK] + b_dq = b_dq * b_z + b_dk = b_dk * b_k + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_V( + q, + k, + z, + dA, + dq, + dk, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_zq = tl.exp(b_zn[None, :] - b_z) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = tl.exp(b_k - b_zn[None, :]).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kz, allow_tf32=False) + b_dq *= b_zq + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * tl.exp(b_kj[None, :] - b_z), 0.) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = tl.exp(b_k - b_zn[None, :]) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_qz = (b_q * tl.exp(b_zn[None, :] - b_z)).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False) + b_dk *= b_kz + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_zj = tl.make_block_ptr(z + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_k - b_zj[None, :]), 0.) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_K( + v, + z, + do, + dA, + s_v_h, + s_v_t, + s_v_d, + scale, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.exp(b_v - b_zn[:, None]).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_v, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * tl.exp(b_v[None, :] - b_z), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + + +@triton.jit +def chunk_abc_bwd_kernel_K( + q, + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), (i_p * V + i_v * BV,), (BV,), (0,)) + p_zc = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.exp(b_v - b_zc[None, :]).to(b_v.dtype) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = tl.exp(b_zp[None, :] - b_z) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * b_z * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_intra_KV( + v, + z, + A, + do, + dv, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_zn[None, :] - b_z)).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= tl.exp(b_v - b_zn[None, :]) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, tl.exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.) + p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_abc_bwd_kernel_rcum_inter( + s, + z, + ss, + doo, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + + b_sp = tl.zeros([BS,], dtype=tl.float32) + b_zp = tl.full([BS,], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (s_s_d,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + # [BS,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + + b_doo = tl.exp(b_s - b_zp[None, :]) * b_sp[None, :] + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + # [BS,] + b_sp = b_sp * tl.exp(b_zc - b_zp) + tl.sum(b_ss * tl.exp(b_zc[None, :] - b_z), 0) + b_zp = b_zc + + +@triton.jit +def chunk_abc_bwd_kernel_rcum_intra( + s, + z, + ss, + doo, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + NC: tl.constexpr +): + i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_i = tl.arange(0, BC) + m_o = tl.full([BC, BC], 1., dtype=tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * s_s_h, (T*S,), (s_s_d,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,)) + p_doo = tl.make_block_ptr(doo + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BS,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + + b_doo = tl.zeros([BC, BS], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + # [BC, BS] + b_doo += b_ss * tl.exp(b_zn[None, :] - b_z) + b_doo = tl.exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False) + + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * s_s_h, (T * S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + # [BS,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_ss = tl.load(p_ss, boundary_check=(0,)) + # [BC, BS] + m_i = o_i[:, None] <= j + b_doo += tl.where(m_i, tl.exp(b_s - b_z[None, :]) * b_ss[None, :], 0.) + b_doo += tl.load(p_doo, boundary_check=(0, 1)) + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, s, initial_state, output_final_state): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_pre(s, B, H, T, S): + # keep cummulative normalizer in fp32 + z = torch.empty_like(s, dtype=torch.float) + grid = (B * H,) + logcumsumexp_fwd_kernel[grid]( + s, z, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S + ) + return z + + def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_abc_fwd_kernel_h[grid]( + k, v, z, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + z = fwd_pre(s, B, H, T, M) + scale = K ** -0.5 + hk = fwd_inner( + q=q, k=k, v=s, z=z, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + normk=False, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None + ) + ok1 = torch.empty_like(s) + Ak = q.new_empty(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_fwd_kernel_K[grid]( + q, k, z, hk, ok1, Ak, + k.stride(1), k.stride(2), k.stride(3), + s.stride(1), s.stride(2), s.stride(3), + hk.stride(1), hk.stride(2), hk.stride(3), + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + num_warps=num_warps, + num_stages=num_stages + ) + ok0 = torch.empty_like(s) + grid = (NM, NT * NC, B * H) + chunk_abc_fwd_kernel_intra_K[grid]( + s, z, ok0, Ak, + s.stride(1), s.stride(2), s.stride(3), + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ok = ok0.add_(ok1) + + scale = 1. + # equivalent to: + # p = ok.softmax(-1, torch.float) + # p is kept in fp32 for safe softmax backward + p = torch.empty_like(ok, dtype=torch.float) + grid = (NT, B * H) + softmax_fwd_kernel[grid]( + ok, p, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + qv = p.to(q.dtype) + + scale = 1. + hv = fwd_inner( + q=qv, k=s, v=v, z=z, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + normk=True, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None + ) + Av = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_fwd_kernel_intra_V[grid]( + qv, s, z, Av, + s.stride(1), s.stride(2), s.stride(3), + scale=scale, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + Av = Av.sum(0) + ov = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_abc_fwd_kernel_V[grid]( + qv, v, z, hv, ov, Av, + s.stride(1), s.stride(2), s.stride(3), + v.stride(1), v.stride(2), v.stride(3), + hv.stride(1), hv.stride(2), hv.stride(3), + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av) + ctx.BT = BT + return ov, final_state + + @staticmethod + @contiguous + def backward(ctx, dov, dht=None): + q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_abc_bwd_kernel_dh[grid]( + q, z, do, dh, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale=scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS): + doo = torch.empty_like(s) + grid = (NS, B * H) + chunk_abc_bwd_kernel_rcum_inter[grid]( + s, z, ss, doo, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BT=BT, BS=BS, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NS, NT * NC, B * H) + chunk_abc_bwd_kernel_rcum_intra[grid]( + s, z, ss, doo, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + return doo + + scale = 1. + qv = p.to(q.dtype) + dhv = bwd_inner( + qv, z, dov, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + scale=scale, + normk=True + ) + dp1 = torch.empty_like(p) + dsv1 = torch.empty_like(s, dtype=torch.float) + dv = v.new_empty(NM, *v.shape) + dAv = q.new_zeros(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_bwd_kernel_V[grid]( + s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv, + s.stride(1), s.stride(2), s.stride(3), + v.stride(1), v.stride(2), v.stride(3), + hv.stride(1), hv.stride(2), hv.stride(3), + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dp0 = torch.empty_like(p) + dsv0 = s.new_zeros(s.shape, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_V[grid]( + qv, s, z, dAv, dp0, dsv0, + s.stride(1), s.stride(2), s.stride(3), + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dp = dp1.add_(dp0) + dsv = dsv1.add_(dsv0) + + # softmax gradient, equivalent to: + # dok = p * (dp - (p * dp).sum(-1, True)) + dok = torch.empty_like(ok) + grid = (NT, B * H) + softmax_bwd_kernel[grid]( + p, dp, dok, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + scale = K ** -0.5 + dhk = bwd_inner( + q, z, dok, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + scale=scale, + normk=False + ) + dAk = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_bwd_kernel_intra_K[grid]( + s, z, dok, dAk, + s.stride(1), s.stride(2), s.stride(3), + scale=scale, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dAk = dAk.sum(0) + + Ak = q.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float) + grid = (NK, NT, B * H) + chunk_abc_bwd_kernel_K[grid]( + q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk, + q.stride(1), q.stride(2), q.stride(3), + s.stride(1), s.stride(2), s.stride(3), + hk.stride(1), hk.stride(2), hk.stride(3), + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + num_warps=num_warps, + num_stages=num_stages + ) + Ak = Ak.sum(0) + dsk1 = dsk1.sum(0) + dsk0 = torch.empty_like(s, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_KV[grid]( + s, z, Ak, dok, dsk0, + s.stride(1), s.stride(2), s.stride(3), + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ds = dsv.add_(dsk1.add_(dsk0)) + ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM) + ds = ds.to(s.dtype) + return dq, dk, dv, ds, None, None + + +def chunk_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False +) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + if initial_state is not None: + initial_state = tuple(i.detach() for i in initial_state) + ov, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state) + return ov, final_state diff --git a/finetune/lora/v6/fla/ops/abc/chunk_gate.py b/finetune/lora/v6/fla/ops/abc/chunk_gate.py new file mode 100644 index 0000000..3cb9801 --- /dev/null +++ b/finetune/lora/v6/fla/ops/abc/chunk_gate.py @@ -0,0 +1,1287 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import (chunk_reversed_cumsum_fwd, softmax_bwd_kernel, + softmax_fwd_kernel) +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BS': 16}, num_warps=2), + triton.Config({'BS': 16}, num_warps=4), + triton.Config({'BS': 16}, num_warps=8), + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_gated_abc_fwd_kernel_cum( + s, + o, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_h( + k, + v, + g, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + GATEK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if GATEK: + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_h *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype) + else: + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_h *= tl.exp(b_gn)[None, :] + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_intra_K( + v, + g, + o, + A, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_gv = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * tl.exp(b_gn[None, :] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_vg, allow_tf32=False) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o *= tl.exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_gv = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + b_vg = b_v[None, :] * tl.exp(b_g - b_gv[None, :]) + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * b_vg, 0.) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + + b_o += tl.load(p_o, boundary_check=(0, 1)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_K( + q, + k, + h, + g, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o = b_o * tl.exp(b_g) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_intra_V( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) + tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + p_gk = tl.advance(p_gk, (K,)) + + +@triton.jit +def chunk_gated_abc_fwd_kernel_V( + q, + v, + g, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_dh( + q, + g, + do, + dh, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + GATEK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if GATEK: + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_g)).to(b_q.dtype) + else: + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[None, :] + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g)).to(b_do.dtype) + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_V( + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + b_dq = b_dq * tl.exp(b_gk) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_V( + q, + k, + g, + dA, + dq, + dk, + dg, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + OVERWRITE: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg, allow_tf32=False) + b_dq *= tl.exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_q * b_dq - b_k * b_dk + if not OVERWRITE: + b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_K( + v, + g, + do, + dA, + s_v_h, + s_v_t, + s_v_d, + scale, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_gv = tl.make_block_ptr(g + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * tl.exp(b_gn[:, None] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_vg, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_gv = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + b_gv = tl.load(p_gv, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * b_v[None, :] * tl.exp(b_g - b_gv[None, :]), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + p_gv = tl.advance(p_gv, (V,)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_K( + q, + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K*V, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_v = b_v * tl.exp(b_gn[None, :] - b_g) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g) * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v.to(b_dh.dtype), tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.exp(b_gn[None, :] - b_g) * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gated_abc_bwd_kernel_intra_KV( + v, + g, + o, + A, + do, + dv, + dg, + s_v_h, + s_v_t, + s_v_d, + T: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + OVERWRITE: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_gv = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T*V,), (s_v_d,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * tl.exp(b_g - b_gn[None, :])).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= tl.exp(b_gn[None, :] - b_gv) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_g = tl.load(p_g, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, tl.exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32) + b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_o * b_do - b_v * b_dv + if not OVERWRITE: + b_dg = b_dg + tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_pre(g, B, H, T, S, BT): + NT = triton.cdiv(T, BT) + g_org, g = g, torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_gated_abc_fwd_kernel_cum[grid]( + g_org, g, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=S, BT=BT + ) + return g + + +def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, gatek=False, h0=None, ht=None): + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_gated_abc_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + GATEK=gatek, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + +def fwd_v(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.): + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + gatek=True, + h0=h0, + ht=ht + ) + A = q.new_zeros(NK, B, H, T, BT) + grid = (NK, NT * NC * NC, B * H) + chunk_gated_abc_fwd_kernel_intra_V[grid]( + q, k, g, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=2, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + o = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_gated_abc_fwd_kernel_V[grid]( + q, v, g, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + return o, h, A + + +def fwd_k(q, k, v, g, B, H, T, K, V, BT, BK, BV, BC, h0=None, ht=None, scale=1.): + NT = triton.cdiv(T, BT) + NV = triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + gatek=False, + h0=h0, + ht=ht + ) + o = torch.empty_like(v) + A = q.new_empty(B, H, T, BT) + grid = (NV, NT, B * H) + chunk_gated_abc_fwd_kernel_K[grid]( + q, k, h, g, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NV, NT * NC, B * H) + chunk_gated_abc_fwd_kernel_intra_K[grid]( + v, g, o, A, + v.stride(1), v.stride(2), v.stride(3), + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, + num_warps=2, + num_stages=num_stages + ) + return o, h, A + + +def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, scale, gatek=False): + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_gated_abc_bwd_kernel_dh[grid]( + q, g, do, dh, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + GATEK=gatek, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + +def bwd_v(q, k, v, g, h, A, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.): + NT = triton.cdiv(T, BT) + NK = triton.cdiv(K, BK) + NC = triton.cdiv(BT, BC) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + overwrite_dg = dg is None + dh = bwd_inner( + q, g, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + scale=scale, + gatek=True + ) + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dv = v.new_empty(NK, *v.shape) + dg = torch.empty_like(g, dtype=torch.float) if dg is None else dg + dA = v.new_zeros(B, H, T, BT) + + grid = (NK, NT, B * H) + chunk_gated_abc_bwd_kernel_V[grid]( + k, v, h, g, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NK, NT * NC, B * H) + chunk_gated_abc_bwd_kernel_intra_V[grid]( + q, k, g, dA, dq, dk, dg, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + OVERWRITE=overwrite_dg, + num_warps=num_warps, + num_stages=num_stages + ) + return dq, dk, dv, dg + + +def bwd_k(q, k, v, g, h, o, do, dg, B, H, T, K, V, BT, BK, BV, BC, scale=1.): + NT = triton.cdiv(T, BT) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + NC = triton.cdiv(BT, BC) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + overwrite_dg = dg is None + dh = bwd_inner( + q, g, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + scale=scale, + gatek=False + ) + dA = q.new_zeros(NV, B, H, T, BT) + grid = (NV, NT * NC * NC, B * H) + chunk_gated_abc_bwd_kernel_intra_K[grid]( + v, g, do, dA, + v.stride(1), v.stride(2), v.stride(3), + scale, + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + dA = dA.sum(0, dtype=dA.dtype) + + A = do.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + dg = torch.empty_like(g, dtype=torch.float) if dg is None else dg + grid = (NK, NT, B * H) + chunk_gated_abc_bwd_kernel_K[grid]( + q, k, v, h, g, A, do, dh, dq, dk, dv, dA, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NV, NT * NC, B * H) + chunk_gated_abc_bwd_kernel_intra_KV[grid]( + v, g, o, A, do, dv, dg, + v.stride(1), v.stride(2), v.stride(3), + T=T, V=V, BT=BT, BC=BC, BV=BV, NC=NC, + OVERWRITE=overwrite_dg, + num_warps=num_warps, + num_stages=num_stages + ) + return dq, dk, dv, dg + + +class ChunkGatedABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, s, g, scale, initial_state, output_final_state, checkpoint_level): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + g_org, g = g, fwd_pre(g, B, H, T, M, BT) + ok, hk, _ = fwd_k( + q=q, k=k, v=s, g=g, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None, + scale=scale + ) + + # equivalent to: + # p = ok.softmax(-1, torch.float) + # p is kept in fp32 for safe softmax backward + p = torch.empty_like(ok, dtype=torch.float) + def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), B * H) + softmax_fwd_kernel[grid]( + ok, p, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + ov, hv, Av = fwd_v( + q=p.to(q.dtype), k=s, v=v, g=g, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None, + scale=1. + ) + + if checkpoint_level >= 1: + del g + g = g_org + if checkpoint_level > 1: + del hk + del hv + hk, hv = None, None + initial_state = tuple() if initial_state is None else initial_state + else: + initial_state = tuple() + + ctx.save_for_backward(q, k, v, s, g, ok, p, hk, hv, Av, *initial_state) + ctx.checkpoint_level = checkpoint_level + ctx.scale = scale + ctx.BT = BT + return ov, final_state + + @staticmethod + @contiguous + def backward(ctx, dov, dht=None): + q, k, v, s, g, ok, p, hk, hv, Av, *initial_state = ctx.saved_tensors + qv = p.to(q.dtype) + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + + if ctx.checkpoint_level >= 1: + g = fwd_pre(g, B, H, T, M, BT) + + # rerun the forward pass to get h if checkpoint_level >= 1 + if ctx.checkpoint_level > 1: + hk = fwd_inner( + q=q, k=k, v=s, g=g, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, + gatek=False, + h0=initial_state[0] if len(initial_state) > 0 else None, + ht=None + ) + hv = fwd_inner( + q=qv, k=s, v=v, g=g, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, + gatek=True, + h0=initial_state[1] if len(initial_state) > 0 else None, + ht=None + ) + + dqv, dsv, dv, dg = bwd_v( + q=qv, k=s, v=v, g=g, h=hv, A=Av, do=dov, dg=None, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, BC=BC, + scale=1. + ) + + # softmax gradient, equivalent to: + # dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dok = torch.empty_like(ok) + def grid(meta): return (triton.cdiv(meta['T'], meta['BT']), B * H) + softmax_bwd_kernel[grid]( + p, dqv, dok, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=M, BT=BT + ) + + dq, dk, dsk, dg = bwd_k( + q=q, k=k, v=s, g=g, h=hk, o=ok, do=dok, dg=dg, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, BC=BC, + scale=ctx.scale + ) + + ds = dsv.add_(dsk) + # reversed cumsum, equivalent to: + # + # def reversed_cumsum(x, dim=-1): + # c = x.cumsum(dim) + # return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c + dg = chunk_reversed_cumsum_fwd(dg).to(s.dtype) + return dq, dk, dv, ds, dg, None, None, None, None + + +def chunk_gated_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False, + checkpoint_level: Optional[int] = 2 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + g (torch.Tensor): + Forget gates of shape `(B, H, T, M)` applied to keys. + If not provided, this function is equivalent to vanilla ABC. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[Tuple[torch.Tensor]]): + Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher values will save more memories and do more recomputations during backward. + Default: `2`: + - Level `0`: no memory saved, no recomputation. + - Level `1`: recompute the fp32 cumulative values during backward. + - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward. + """ + assert checkpoint_level in [0, 1, 2] + if initial_state is not None: + initial_state = tuple(i.detach() for i in initial_state) + if g is None: + # TODO: this 3 steps took huge amount of time, ought to be optimized + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z).to(k.dtype) + if scale is None: + scale = q.shape[-1] ** -0.5 + ov, final_state = ChunkGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state, checkpoint_level) + return ov, final_state diff --git a/finetune/lora/v6/fla/ops/abc/naive.py b/finetune/lora/v6/fla/ops/abc/naive.py new file mode 100644 index 0000000..5abc3f5 --- /dev/null +++ b/finetune/lora/v6/fla/ops/abc/naive.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = q.dtype + + # [batch_size, n_heads, seq_len, n_slots] + if g is None: + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z) + q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + + hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) + ok = torch.zeros_like(s) + + if scale is None: + scale = q.shape[-1] ** -0.5 + + final_state = None + if initial_state is not None: + hk += initial_state[0] + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = s[:, :, i] + g_i = g[:, :, i].exp() + hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] + ok[:, :, i] = (q_i[..., None] * hk).sum(-2) + + qv = ok.softmax(-1) + hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) + ov = torch.zeros_like(v) + if initial_state is not None: + hv += initial_state[1] + + for i in range(T): + q_i = qv[:, :, i] + k_i = s[:, :, i] + v_i = v[:, :, i] + g_i = g[:, :, i].exp() + hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] + ov[:, :, i] = (q_i[..., None] * hv).sum(-2) + + if output_final_state: + final_state = (hk, hv) + return ov.to(dtype), final_state + + +def naive_cumsum_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor +) -> torch.Tensor: + """ + A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. + This is just for demonstration purposes, with no numerical stabilities guaranteed. + """ + + dtype = q.dtype + q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) + + scale = q.shape[-1] ** -0.5 + # [batch_size, n_heads, seq_len, n_slots] + s = (s - s.max(2, True)[0]).exp() + z = s.cumsum(2) + # [batch_size, n_heads, seq_len, n_slots, d_head] + K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + # [batch_size, n_heads, seq_len, n_slots] + p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) + # [batch_size, n_heads, seq_len, d_head] + o = torch.einsum('...m,...md->...d', p, V) + return o.to(dtype), None diff --git a/finetune/lora/v6/fla/ops/abc/recurrent_fuse.py b/finetune/lora/v6/fla/ops/abc/recurrent_fuse.py new file mode 100644 index 0000000..3b4491d --- /dev/null +++ b/finetune/lora/v6/fla/ops/abc/recurrent_fuse.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + + +@triton.jit +def fused_recurrent_gated_abc_fwd_kernel( + q, + k, + v, + gk, + gv, + o, + h0, + ht, + s_k_h, + s_v_h, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + REVERSE: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + + h = tl.zeros([BV, BK], dtype=tl.float32) + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * b_gk[None, :] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * b_gv[:, None] + h += b_k[None, :] * b_v[:, None] + b_o = h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + p_q += -K if REVERSE else K + p_k += -K if REVERSE else K + p_o += -V if REVERSE else V + p_v += -V if REVERSE else V + if USE_GK: + p_gk += -K if REVERSE else K + if USE_GV: + p_gv += -V if REVERSE else V + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +@triton.jit +def fused_recurrent_gated_abc_bwd_kernel( + q, + k, + v, + gk, + gv, + do, + dq, + dk, + dv, + h0, + s_k_h, + s_v_h, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + REVERSE: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + mask_kv = mask_bk[:, None] & mask_bv[None, :] + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * b_gk[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * b_gv[None, :] + h += b_k[:, None] * b_v[None, :] + b_dq = tl.sum(h * b_do[None, :], axis=1) * scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += -K if REVERSE else K + p_v += -V if REVERSE else V + p_q += -K if REVERSE else K + p_do += -V if REVERSE else V + p_dq += -K if REVERSE else K + if USE_GK: + p_gk += -K if REVERSE else K + if USE_GV: + p_gv += -V if REVERSE else V + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + b_dh *= b_gk[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + b_dh *= b_gv[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_q += K if REVERSE else -K + p_k += K if REVERSE else -K + p_v += V if REVERSE else -V + p_do += V if REVERSE else -V + p_dk += K if REVERSE else -K + p_dv += V if REVERSE else -V + if USE_GK: + p_gk += K if REVERSE else -K + if USE_GV: + p_gv += V if REVERSE else -V + + +class FusedRecurrentGatedABCFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, s, g, scale=None, initial_state=None, output_final_state=False, reverse=False): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + # default scale + if scale is None: + scale = K ** -0.5 + + BK, BV, BM = min(K, 32), min(V, 32), min(M, 32) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + num_stages = 1 + num_warps = 1 + + g = g.float().exp() + + final_state = (None, None) + if output_final_state: + final_state = (q.new_empty(B, H, K, M), q.new_empty(B, H, M, V)) + + ok = q.new_empty(NK, B, H, T, M, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, B * H) + fused_recurrent_gated_abc_fwd_kernel[grid]( + q, k, s, gk, gv, ok, initial_state[0], final_state[0], + k.stride(1), + s.stride(1), + scale=scale, + B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM, + USE_INITIAL_STATE=initial_state[0] is not None, + STORE_FINAL_STATE=final_state[0] is not None, + USE_GK=False, + USE_GV=True, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + ok = ok.sum(0) + + qv = ok.softmax(-1, dtype=torch.float) + ov = q.new_empty(NM, B, H, T, V, dtype=torch.float) + gk, gv = g, None + grid = (NV, NM, B * H) + fused_recurrent_gated_abc_fwd_kernel[grid]( + qv, s, v, gk, gv, ov, initial_state[1], final_state[1], + s.stride(1), + v.stride(1), + scale=1., + B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV, + USE_INITIAL_STATE=initial_state[0] is not None, + STORE_FINAL_STATE=final_state[0] is not None, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + ov = ov.sum(0) + + ctx.save_for_backward(q, k, v, s, g, qv, *initial_state, ok) + ctx.scale = scale + ctx.reverse = reverse + # we do not need the gradient of the final state from the next chunk + # similiar to Trunctated BPTT + if final_state is not None: + final_state = tuple(i.detach() for i in final_state) + return ov.to(q.dtype), final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, dht=None): + q, k, v, s, g, qv, *initial_state, ok = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + V = v.shape[-1] + scale = ctx.scale + + BK, BV, BM = min(K, 32), min(V, 32), min(M, 32) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + num_stages = 1 + num_warps = 1 + + dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float) + dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float) + dv = q.new_empty(NM, B, H, T, V, dtype=torch.float) + gk, gv = g, None + grid = (NV, NM, B * H) + fused_recurrent_gated_abc_bwd_kernel[grid]( + qv, s, v, gk, gv, do, dqv, dsv, dv, initial_state[1], + s.stride(1), + v.stride(1), + scale=1., + B=B, H=H, T=T, K=M, V=V, BK=BM, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state[1] is not None, + REVERSE=ctx.reverse, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + dqv = dqv.sum(0) + dsv = dsv.sum(0) + dv = dv.sum(0) + dgk = dqv * qv.float() - dsv * s.float() + dgk_cumsum = dgk.cumsum(-2) + dgk = dgk + dgk_cumsum[:, :, -1, None] - dgk_cumsum + + dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dq = q.new_empty(NM, B, H, T, K, dtype=torch.float) + dk = q.new_empty(NM, B, H, T, K, dtype=torch.float) + dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, B * H) + fused_recurrent_gated_abc_bwd_kernel[grid]( + q, k, s, gk, gv, dok, dq, dk, dsk, initial_state[0], + q.stride(1), + s.stride(1), + scale=scale, + B=B, H=H, T=T, K=K, V=M, BK=BK, BV=BM, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state[0] is not None, + REVERSE=ctx.reverse, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dsk = dsk.sum(0) + + dgv = dok.float() * ok.float() - dsk * s.float() + dgv_cumsum = dgv.cumsum(-2) + dgv = dgv + dgv_cumsum[:, :, -1, None] - dgv_cumsum + + ds = dsk.add_(dsv) + dg = dgk.add_(dgv) + + return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, None, None, None + + +def fused_recurrent_gated_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + g (torch.Tensor): + Forget gates of shape `(B, H, T, M)` applied to keys. + If not provided, this function is equivalent to vanilla ABC. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[Tuple[torch.Tensor]]): + Initial state tuple having tensors of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state tuple, having tensors of shape `(B, H, K, V)`. Default: `False`. + """ + if initial_state is not None: + initial_state = tuple(i.detach() for i in initial_state) + if g is None: + # TODO: this 3 steps took huge amount of time, ought to be optimized + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z).to(k.dtype) + if scale is None: + scale = q.shape[-1] ** -0.5 + ov, final_state = FusedRecurrentGatedABCFunction.apply(q, k, v, s, g, scale, initial_state, output_final_state) + return ov, final_state diff --git a/finetune/lora/v6/fla/ops/based/__init__.py b/finetune/lora/v6/fla/ops/based/__init__.py new file mode 100644 index 0000000..5bcfcdc --- /dev/null +++ b/finetune/lora/v6/fla/ops/based/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk_fuse import fused_chunk_based +from .parallel import parallel_based + +__all__ = [ + 'fused_chunk_based', + 'parallel_based' +] diff --git a/finetune/lora/v6/fla/ops/based/chunk_fuse.py b/finetune/lora/v6/fla/ops/based/chunk_fuse.py new file mode 100644 index 0000000..2f10405 --- /dev/null +++ b/finetune/lora/v6/fla/ops/based/chunk_fuse.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_chunk_based_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L, 1] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] + b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [TB, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + do, # gradient of output [B, H, L, D_head_V] + dz, # gradient of normalizer [B, H, L] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 0.5 + b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), + tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale=1): + batch_size, n_heads, seq_len, d_head_qk = q.shape + # assert d_head_qk == 16, "currently we do not support feature dim other than 16" + d_head_v = v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(d_head_qk, 16), min(d_head_v, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + z = q.new_empty(NK, batch_size, n_heads, seq_len, dtype=torch.float32) + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(d_head_qk, 16), min(d_head_v, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +triton_fused_chunk_based = FusedChunkBasedFunction.apply + + +def fused_chunk_based(q, k, v, use_scale=True, use_normalize=True): + assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_fused_chunk_based(q, k, v, scale) + if use_normalize: + o = o / (z[..., None] + 1e-6) + else: + o = o + + return o.to(q.dtype) diff --git a/finetune/lora/v6/fla/ops/based/naive.py b/finetune/lora/v6/fla/ops/based/naive.py new file mode 100644 index 0000000..fbfabbb --- /dev/null +++ b/finetune/lora/v6/fla/ops/based/naive.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +from fla.ops.based.chunk_fuse import fused_chunk_based +from fla.ops.based.parallel import parallel_based + + +def naive_parallel_based(q, k, v, use_scale=True, use_norm=True): + if use_scale: + q = q * (q.shape[-1] ** -0.5) + attn = q @ k.transpose(-2, -1) + attn = 1 + attn + 1/2 * (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +def naive_chunk_based(q, k, v, chunk_size=256): + q = q * (q.shape[-1] ** -0.5) + + # compute normalizer. + k_cumsum = torch.cumsum(k, dim=-2) + kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) + # first + z = (q * k_cumsum).sum(-1) + # second order + z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 + # zero-th order + z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] + + # compute o + # constant term + _o = v.cumsum(-2) + + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) + + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + + intra_chunk_attn = q @ k.transpose(-2, -1) + intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) + intra_chunk_attn.masked_fill_( + ~torch.tril( + torch.ones(chunk_size, chunk_size, + dtype=torch.bool, device=q.device), + ), 0) + o = intra_chunk_attn @ v + + # quadractic term + kv = torch.einsum( + 'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + + o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) + + # linear term + kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) + + o = rearrange(o, 'b h n c d -> b h (n c) d') + o = o + _o + return o / (z[..., None] + 1e-6) + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 128 + # D = 15 + dtype = torch.float32 + q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) + + do = torch.randn_like(v).cuda() + ref = naive_parallel_based(q, k, v, True, True) + ref.backward(do, retain_graph=True) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + # tri = naive_chunk_based(q, k, v) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + tri = fused_chunk_based(q, k, v, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + tri = parallel_based(q, k, v, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() diff --git a/finetune/lora/v6/fla/ops/based/parallel.py b/finetune/lora/v6/fla/ops/based/parallel.py new file mode 100644 index 0000000..f4e3fad --- /dev/null +++ b/finetune/lora/v6/fla/ops/based/parallel.py @@ -0,0 +1,388 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# Based: An Educational and Effective Sequence Mixer +# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit +def parallel_based_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_based_bwd_kernel( + q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, device=q.device) + z = torch.empty(NK, batch_size, n_heads, seq_len, + device=q.device) + parallel_based_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_based(q, k, v, use_scale=True, use_normalize=True, return_both=False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + 1e-6) + else: + o = o + return o.to(q.dtype) diff --git a/finetune/lora/v6/fla/ops/delta_rule/README.md b/finetune/lora/v6/fla/ops/delta_rule/README.md new file mode 100644 index 0000000..1ab2d48 --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/README.md @@ -0,0 +1,4 @@ +- Delta Rule + +The implementation of delta rule described in https://arxiv.org/abs/2102.11174 + diff --git a/finetune/lora/v6/fla/ops/delta_rule/__init__.py b/finetune/lora/v6/fla/ops/delta_rule/__init__.py new file mode 100644 index 0000000..b0848b3 --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk_fuse import fused_chunk_delta_rule +from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule +from .chunk import chunk_delta_rule + +__all__ = [ + 'fused_chunk_delta_rule', + 'fused_recurrent_linear_attn_delta_rule', + 'chunk_delta_rule' +] diff --git a/finetune/lora/v6/fla/ops/delta_rule/chunk.py b/finetune/lora/v6/fla/ops/delta_rule/chunk.py new file mode 100644 index 0000000..4f1d0cd --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/chunk.py @@ -0,0 +1,544 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl +from fla.ops.utils import contiguous +from torch.cuda.amp import custom_bwd, custom_fwd +from fla.ops.delta_rule.wy_fast import fwd_recompute_w_u, fwd_prepare_wy_repr, bwd_prepare_wy_repr +from fla.ops.delta_rule.chunk_fuse import fused_chunk_delta_rule_fwd, fused_chunk_delta_rule_bwd +# from fla.ops.delta_rule.utils import bwd_prepare_wy_repr + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A , 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_prepare_dv(q, k, do, BT): + dv = torch.empty_like(do) + B, H, T, K, V = *k.shape, do.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H)]( + q, k, do, dv, + k.stride(1), k.stride(2), k.stride(3), + do.stride(1), do.stride(2), do.stride(3), + T, K, V, K**-0.5, BT, BK, BV + ) + return dv + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v, + d, + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32) + # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden + for i_c in range(tl.cdiv(BT, BC)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BK] + b_d = tl.load(p_d, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False) + # [BK, BV] + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + b_h += b_h_cumsum + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + # b_s = tl.dot(b_k, b_q, allow_tf32=False) + # b_s = tl.where(m_s, b_s, 0) + # b_dv = tl.dot(b_s.to(b_do.dtype), b_do, allow_tf32=False) + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + + # [BT, BT] + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + + +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape, u.shape[-1] + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty_like(u) + chunk_delta_rule_fwd_kernel_h[grid]( + k, u, w, v_new, h, initial_state, final_state, + k.stride(1), k.stride(2), k.stride(3), + u.stride(1), u.stride(2), u.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + + +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B, H, T, K, V = *q.shape, do.shape[-1] + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B, H, NT * K, V) + # dv_new = torch.empty_like(do) + grid = (NK, NV, B * H) + dv2 = torch.empty_like(dv) + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, + ) + return dh, dv2 + + +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + + BK = triton.next_power_of_2(K) + o = torch.empty_like(v_new) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(K), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H) + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + q.stride(1), q.stride(2), q.stride(3), + v_new.stride(1), v_new.stride(2), v_new.stride(3), + h.stride(1), h.stride(2), + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + ) + return o + + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + + BK = triton.next_power_of_2(K) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dw = torch.empty_like(w) + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + q.stride(1), q.stride(2), q.stride(3), + v_new.stride(1), v_new.stride(2), v_new.stride(3), + dh.stride(1), dh.stride(2), + scale = K ** -0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + ) + return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1): + ### obtain WY representation. u is actually the new v. + w, u, A = fwd_prepare_wy_repr(k, v, beta, BT) + # ### forward_h + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False) + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state) + ## obtain output + o = chunk_fwd_o_fn(q, k, v_new, h, BT) + # save memory + if checkpoint_level == 1: + h, v_new = None, None + ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state) + ctx.BT = BT + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_ht=None): + q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors + scale = q.shape[-1] ** -0.5 + BT = ctx.BT + w, u = fwd_recompute_w_u(k, v, beta, A, BT) + # checkpont_level=1, recomputation. + if h is None: + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None) + dv = fwd_prepare_dv(q, k, do, BT) + dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT) + dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT) + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT) + dk.add_(dk2) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None + +def chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + assert q.dtype == k.dtype == v.dtype + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py b/finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py new file mode 100644 index 0000000..efb36fd --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/chunk_fuse.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from packaging import version +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr +from fla.utils import contiguous + + +# on-the-fly computation without materializing hidden statets into HBMs +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8) + ], + key=["BT", "BK"], +) +@triton.jit +def fused_chunk_delta_rule_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + v_new, + d, # decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False) + b_v = b_v - b_v_prime + tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_v_new = tl.advance(p_v_new, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_d = tl.advance(p_d, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fused_chunk_delta_rule_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + d, # decay [B, H, L, D_head_K] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dd, # gradient of decay [NV, B, H, L, D_head_K] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + # first reverse + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + + tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + NT = tl.cdiv(T, BT) + for i in range(0, NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + if i < (NT - 1): + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + ((i+1) * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1)) + + +def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BT = BT + # ctx.BT = BT + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, 'NK should be 1' + o = q.new_empty(batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = True + # if version.parse(triton.__version__) < version.parse('2.2.0'): + # import warnings + # warnings.warn( + # "Triton<2.2.0 detected for running this kernel, " + # "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + # "that lead to significant precision loss. " + # "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + # "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + # ) + # CHECK = True + grid = (NV, NK, batch_size * n_heads) + v_new = torch.empty_like(v) + fused_chunk_delta_rule_fwd_kernel[grid]( + q, k, v, v_new, d, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + ) + return o, v_new, CHECK, final_state + + +def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1 + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + fused_chunk_delta_rule_bwd_kernel[grid]( + q, k, v, d, do, dq, dk, dv, dd, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=CHECK, + # num_warps=num_warps, + # num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dd = dd.sum(0) + dd[:, :, 0:BT] = 0 + return dq, dk, dv, dd + +class FusedChunkDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0): + # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory. + assert checkpoint_level in [0, 1] + k_origin = k + # k = _l2_norm_fwd(k_origin) + k = k + d, v_new = fwd_prepare_wy_repr(k, v, beta, BT) + o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state) + if checkpoint_level == 1: + d, v_new = None, None + ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state) + ctx.CHECK = CHECK + ctx.chunk_size = BT + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_final_state=None): + q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors + chunk_size = ctx.chunk_size + k = k_origin + # k = _l2_norm_fwd(k_origin) + if d is None: + d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state) + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size) + dk.add_(dk2) + # dk = _l2_norm_bwd(k_origin, dk) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None + + +def fused_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + return o, final_state + + +def delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ == "__main__": + import torch.nn.functional as F + seq_len = 128 + b = 2 + h = 4 + q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1) + beta = torch.rand(b, h, seq_len).sigmoid() + q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta)) + do = torch.rand_like(v) + o2 = delta_rule_recurrence(q, k, v.clone(), beta) + o2.backward(do, retain_graph=True) + q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad + q.grad = k.grad = v.grad = beta.grad = None + o, _ = fused_chunk_delta_rule(q, k, v, beta, 32) + o.backward(do, retain_graph=True) + q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad + q.grad = k.grad = v.grad = beta.grad = None + print((o - o2).abs().max()) + print((q_grad - q_grad2).abs().max()) + print((k_grad - k_grad2).abs().max()) + print((v_grad - v_grad2).abs().max()) + print((beta_grad - beta_grad2).abs().max()) diff --git a/finetune/lora/v6/fla/ops/delta_rule/naive.py b/finetune/lora/v6/fla/ops/delta_rule/naive.py new file mode 100644 index 0000000..45ca247 --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/naive.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + + assert l % chunk_size == 0 + + # note that diagonal is masked. + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) + attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + # u + k_cumsum = attn @ v + # w + k_cumdecay = attn @ k_beta + + v = k_cumsum + S = k.new_zeros(b, h, d_k, d_v) + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) + v_prime = k_cumdecay[:, :, i] @ S + v_new = v_i - v_prime + o_inter = q_i @ S + o[:, :, i] = o_inter + attn @ v_new + # chunk state update + S = S + k_i.transpose(-1, -2) @ v_new + + return rearrange(o, 'b h n c d -> b h (n c) d') + + +if __name__ == '__main__': + B = 2 + H = 4 + L = 256 + DK = 128 + DV = 128 + q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) + k = (torch.randn(B, H, L, DK)).cuda() + k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) + v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) + beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) + + o = delta_rule_recurrence(q, k, v, beta) + do = torch.randn(B, H, L, DV).cuda() + o.backward(do, retain_graph=True) + q_grad, q.grad = q.grad, None + k_grad, k.grad = k.grad, None + v_grad, v.grad = v.grad, None + beta_grad, beta.grad = beta.grad, None + + o2 = delta_rule_chunkwise(q, k, v, beta) + o2.backward(do) + assert torch.allclose(o, o2, atol=1e-4), breakpoint() + assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint() + assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint() + assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint() + assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint() + print("All passed!") diff --git a/finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py b/finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py new file mode 100644 index 0000000..6bd2426 --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/recurrent_fuse.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_recurrent_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V]. + beta, # beta [B, H, L] + o, # output [B, H, L, D_head_V] + initial_state, + final_state, # final hidden state [B, H, D_head_K, D_head_V] + + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < DK + mask_bv = (i_v * BV + tl.arange(0, BV)) < DV + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * _k[None, :], axis=1) + _v -= _v_minus + _beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, _v.to(p_v.dtype.element_ty), mask=mask_bv) + _v *= _beta + h += _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += DK + p_k += DK + p_o += DV + p_v += DV + p_beta += 1 + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + beta, # beta [B, H, L] + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dbeta, # gradient of beta [B, H, L] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_beta = beta + i_bh * T + T - 1 + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + (T - 1) * DK + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + (T - 1) * DV + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _beta = tl.load(p_beta).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :] * _beta, axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + + d_beta = tl.sum(d_v * _v) + d_v = d_v * _beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= _k[:, None] * d_v[None, :] + + p_do -= DV + p_q -= DK + p_k -= DK + p_v -= DV + p_dk -= DK + p_dv -= DV + p_dbeta -= 1 + p_beta -= 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + DV + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + DK + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _beta = tl.load(p_beta).to(tl.float32) + _v *= _beta + + h += _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += DK + p_do += DV + p_v += DV + p_dk += DK + p_dv += DV + p_dq += DK + p_beta += 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, beta, initial_state=None, output_final_state=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 8) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, beta, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + dbeta = q.new_empty(NV, batch_size, n_heads, seq_len) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None + + +def fused_recurrent_linear_attn_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/delta_rule/utils.py b/finetune/lora/v6/fla/ops/delta_rule/utils.py new file mode 100644 index 0000000..92eafdd --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/utils.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous +from fla.ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2 + + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + o, + o2, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = tl.arange(0, BK) < K + mask_bv = tl.arange(0, BV) < V + mask_bk = mask_bk[None, :] & mask_bt[:, None] + mask_bv = mask_bv[None, :] & mask_bt[:, None] + # [BT, BK] + b_k = tl.load(p_k, mask=mask_bk, other=0) + # [BT,] + b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32) + # [BT, BV] + b_v = tl.load(p_v, mask=mask_bv, other=0) + b_v = (b_v * b_beta[:, None]).to(b_v.dtype) + # [BT, BK] + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + # [BT, BT] + b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + b_A = b_A.to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + b_u = tl.dot(b_A, b_v, allow_tf32=False) + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk) + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta, + o, o2, do, do2, + dk, dv, dbeta, + NT, K, V, T, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None] + mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None] + b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt) + + b_beta = b_beta.to(tl.float32) + A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None] + A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0) + b_do = tl.load(p_do, mask=mask_bk).to(tl.float32) + b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32) + dA = tl.zeros([BT, BT], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i in range(BT-1, -1, -1): + mask = tl.arange(0, BT) == i + attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0) + do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0) + dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0) + b_do = b_do - attn[:, None] * do_[None, :] + b_dv = b_dv - attn[:, None] * dv_[None, :] + tl.debug_barrier() + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_v = tl.load(p_v, mask=mask_bv) + b_dk += b_do * b_beta[:, None] + b_dbeta = tl.sum(b_do * b_k, axis=1) + b_dbeta += tl.sum(b_dv * b_v, axis=1) + b_v = None + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_o = tl.load(p_o, mask=mask_bk) + b_o2 = tl.load(p_o2, mask=mask_bv) + + dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False) + dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype), + allow_tf32=False) + dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0) + b_dv *= b_beta[:, None] + p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + + b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1) + dA = dA * b_beta[:, None] + b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False) + b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False) + p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt) + + +def fwd_prepare_wy_repr(k, v, beta, chunk_size): + B, H, T, K, V = *k.shape, v.shape[-1] + v_new = torch.empty_like(v) + o_cumdecay = torch.empty_like(k) + BT = chunk_size + NT = triton.cdiv(T, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, o_cumdecay, v_new, + T, K, V, BT, BK, BV + ) + return o_cumdecay, v_new + + +def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size): + b, h, l, d_k = do.shape + d_v = v.shape[-1] + BK = triton.next_power_of_2(d_k) + BV = triton.next_power_of_2(d_v) + c = chunk_size + BK = d_k + NT = triton.cdiv(l, c) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, b*h)]( + k, v, beta, + o_cumdecay, v_new, do, do2, + dk, dv, dbeta, + NT, d_k, d_v, l, chunk_size, BK, BV + ) + return dk, dv, dbeta + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, k, v, beta, chunk_size): + o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + ctx.chunk_size = chunk_size + ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new) + return o_cumdecay, v_new + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, do2): + k, v, beta, o_cumdecay, v_new = ctx.saved_tensors + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size) + return dk, dv, dbeta, None + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 2048 + b = 4 + h = 8 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 256) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + print("Start warmup.") + o1, o2 = prepare_wy_repr(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + o3, o4 = prepare_wy_repr2(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + print((o1 - o3).abs().max()) + print((o2 - o4).abs().max()) + + + for i in range(30): + o1, o2 = prepare_wy_repr(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + o1, o2 = prepare_wy_repr2(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + + print("Done warmup.") + + import time + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr2(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + + \ No newline at end of file diff --git a/finetune/lora/v6/fla/ops/delta_rule/wy_fast.py b/finetune/lora/v6/fla/ops/delta_rule/wy_fast.py new file mode 100644 index 0000000..750565e --- /dev/null +++ b/finetune/lora/v6/fla/ops/delta_rule/wy_fast.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1)) + b_A = b_A.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta, A, + dw, du, + dk, dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + # store + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + b_A2 = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_A2 += tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + # store + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_A -= (tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]) + b_A2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_A2, 0) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + tl.debug_barrier() + + for i in range(BT-1, 0, -1): + mask = tl.arange(0, BT) == i + b_da = tl.sum(tl.where(mask[:, None], b_dA, 0), 0) + b_a = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_da2 = b_da + tl.sum(b_da[None, :] * b_A, 1) + b_dA = tl.where(mask[:, None], b_da2, b_dA) + b_dA += b_da[None, :] * b_a[:, None] + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + tl.debug_barrier() + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty),boundary_check=(0,)) + + +def fwd_prepare_wy_repr(k, v, beta, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + u = torch.empty_like(v) + w = torch.empty_like(k) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return w, u, A + + + +def fwd_recompute_w_u(k, v, beta, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + u = torch.empty_like(v) + w = torch.empty_like(k) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return w, u + + + + + +def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, A, + dw, du, + dk, dv, dbeta, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, k, v, beta, chunk_size): + ctx.BT = chunk_size + w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT) + ctx.save_for_backward(k, v, beta, A) + return w, u + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, dw, du): + k, v, beta, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT) + return dk, dv, dbeta, None + + + + +prepare_wy_repr = WYRepresentationPrepration.apply + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.float32) + seq_len = 1024 + b = 4 + h = 4 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + # beta = torch.ones(b, h, seq_len) + require_grad = True + + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64) + if require_grad: + o1.backward(do, retain_graph=True) + o2.backward(do2, retain_graph=True) + + k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad + k.grad = v.grad = beta.grad = None + + o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone()) + print((o1-o3).abs().max()) + print((o2-o4).abs().max()) + + if require_grad: + o3.backward(do, retain_graph=True) + o4.backward(do2, retain_graph=True) + k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad + print((k_grad2-k_grad).abs().max()) + print((v_grad2-v_grad).abs().max()) + print((beta_grad2-beta_grad).abs().max()) + breakpoint() + diff --git a/finetune/lora/v6/fla/ops/gla/__init__.py b/finetune/lora/v6/fla/ops/gla/__init__.py new file mode 100644 index 0000000..f1fdb95 --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gla +from .chunk_fuse import fused_chunk_gla +from .recurrent_fuse import fused_recurrent_gla + +__all__ = [ + 'chunk_gla', + 'fused_chunk_gla', + 'fused_recurrent_gla' +] diff --git a/finetune/lora/v6/fla/ops/gla/chunk.py b/finetune/lora/v6/fla/ops/gla/chunk.py new file mode 100644 index 0000000..7c83529 --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/chunk.py @@ -0,0 +1,734 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_reversed_cumsum_fwd +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BS': 16}, num_warps=2), + triton.Config({'BS': 16}, num_warps=4), + triton.Config({'BS': 16}, num_warps=8), + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_gla_fwd_kernel_cum( + s, + o, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gla_fwd_kernel_h( + k, + v, + g, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + if i_t < NT - 1: + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + else: + b_gn = tl.min(b_g, axis=1) + b_h *= tl.exp(b_gn)[:, None] + b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype) + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gla_fwd_kernel_intra( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) + tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + p_gk = tl.advance(p_gk, (K,)) + + +@triton.jit +def chunk_gla_fwd_kernel_inter( + q, + v, + g, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gla_bwd_kernel_dh( + q, + g, + do, + dh, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_g)).to(b_q.dtype) + + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit +def chunk_gla_bwd_kernel_inter( + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + b_dq = b_dq * tl.exp(b_gk) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gla_bwd_kernel_intra( + q, + k, + g, + dA, + dq, + dk, + dg, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg, allow_tf32=False) + b_dq *= tl.exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level): + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_gla_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float) + + g_org, g = g, torch.empty_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_gla_fwd_kernel_cum[grid]( + g_org, g, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=final_state if final_state is not None else None + ) + A = q.new_zeros(NK, B, H, T, BT) + grid = (NK, NT * NC * NC, B * H) + chunk_gla_fwd_kernel_intra[grid]( + q, k, g, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + o = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_gla_fwd_kernel_inter[grid]( + q, v, g, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + if checkpoint_level >= 1: + del g + g = g_org + if checkpoint_level > 1: + del h + h, initial_state = None, None + + ctx.save_for_backward(q, k, v, g, h, initial_state, A) + ctx.BT = BT + ctx.scale = scale + ctx.checkpoint_level = checkpoint_level + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + q, k, v, g, h, initial_state, A = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_gla_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + def bwd_inner(q, g, do, B, H, T, K, V, BT, BK, BV, NT, scale): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_gla_bwd_kernel_dh[grid]( + q, g, do, dh, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + if ctx.checkpoint_level >= 1: + # save the original g and compute its fp32 cumsum during the backward pass for memory consideration + g_org, g = g, torch.zeros_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_gla_fwd_kernel_cum[grid]( + g_org, g, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + + # rerun the forward pass to get h if checkpoint_level >= 1 + if ctx.checkpoint_level > 1: + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=None + ) + + scale = ctx.scale + dh = bwd_inner( + q, g, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + scale=scale + ) + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dg = torch.empty_like(k, dtype=torch.float) + dv = v.new_empty(NK, *v.shape) + dA = q.new_zeros(B, H, T, BT) + grid = (NK, NT, B * H) + chunk_gla_bwd_kernel_inter[grid]( + k, v, h, g, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NK, NT * NC, B * H) + chunk_gla_bwd_kernel_intra[grid]( + q, k, g, dA, dq, dk, dg, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + + dq = dq.to(q.dtype) + dk = dk.to(q.dtype) + # reversed cumsum, equivalent to: + # + # def reversed_cumsum(x, dim=-1): + # c = x.cumsum(dim) + # return x + c.index_select(dim, x.new_tensor([c.shape[dim]-1], dtype=torch.long)) - c + dg = chunk_reversed_cumsum_fwd(dg).to(k.dtype) + return dq, dk, dv, dg, None, None, None, None + + +def chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + checkpoint_level: Optional[int] = 2 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + g (torch.Tensor): + Forget gates of shape `(B, H, T, K)` applied to keys. + scale (Optional[int]): + Scale factor for the GLA attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher values will save more memories and do more recomputations during backward. + Default: `0`: + - Level `0`: no memory saved, no recomputation. + - Level `1`: recompute the fp32 cumulative values during backward. + - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward. + """ + assert checkpoint_level in [0, 1, 2] + if scale is None: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/gla/chunk_fuse.py b/finetune/lora/v6/fla/ops/gla/chunk_fuse.py new file mode 100644 index 0000000..affbcf2 --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/chunk_fuse.py @@ -0,0 +1,548 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang +# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635 +# on-the-fly computation without materializing hidden statets into HBMs + +from typing import Tuple + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from packaging import version +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum, + prepare_qg_kg) +from fla.utils import contiguous + +inv_ln2 = 1.44269504 + +@triton.jit +def fused_chunk_gla_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + g, # cumulative sum of log decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < DK + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) + if CHECK and i == 0: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + else: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_db += BT * DK + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_gla_bwd_kernel( + q, k, v, g, + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < DK + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # b_g = tl.load(p_g, boundary_check=(0, 1)) * inv_ln2 + d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) + + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + + # cum = tl.zeros([BK], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32) + + # inter-chunk + # [DK, DV] + if CHECK and i == 1: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + else: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fwd_inner_chunk( + q, k, g, A, + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + DK: tl.constexpr, # D_head_K +): + + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < DK + o_i = tl.arange(0, BT) + + p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) + p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) + p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + for i in range(BT): + _q = tl.load(p_q, mask=mask, other=0) * scale + gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g) + score = tl.sum(s, axis=1) + score = tl.where(o_i <= i, score, 0) + tl.store(p_A, score.to(p_A.dtype.element_ty)) + p_q += DK + p_gq += DK + p_A += BT + + +@triton.jit +def bwd_inner_chunk( + q, + k, + g, + dA, + dq, + dk, + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + DK: tl.constexpr, # D_head_K +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < DK + o_i = tl.arange(0, BT) + + p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) + p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) + p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) + p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + + for i in range(BT): + _q = tl.load(p_q, mask=mask, other=0) + gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + score = tl.math.exp2(gq[None, :] - b_g) + score = tl.where(o_i[:, None] <= i, score, 0) + _dA = tl.load(p_dA) + _dA = tl.where(o_i <= i, _dA, 0) + b_dk += (_dA[:, None] * score * _q[None, :]) + b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0) + tl.store(p_dq, b_dq, mask=mask) + p_q += DK + p_dq += DK + p_gq += DK + p_dA += BT + + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): + ctx.g_dtype = g.dtype + g_original = g + # cumulative decay should be in float32, otherwise the err will be accumulated and amplified. + g = torch.empty_like(g, dtype=torch.float32) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + ctx.scale = scale + + # inter-chunk + BT = 16 # chunk_size + BK, BV = min(d_head_qk, 64), min(d_head_v, 64) + num_stages = 1 + num_warps = 2 + + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) + fwd_decay_cumsum[grid]( + g_original, + g, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, BK=BK, DK=d_head_qk, num_warps=1 + ) + prepare_qg_kg[grid]( + q, k, g, q_g, k_g, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, BK=BK, DK=d_head_qk, num_warps=1 + ) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_gla_fwd_kernel[grid]( + q_g, k_g, v, g, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + + # intra-chunk + chunk_size = 16 + num_chunk = seq_len // chunk_size + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + BK = min(d_head_qk, 64) + NK = triton.cdiv(d_head_qk, BK) + A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT) + grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) + fwd_inner_chunk[grid]( + q, k, g, A, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3, + num_warps=4 + ) + A = A.sum(0) + o2 = A @ v2 + o2 = rearrange(o2, 'b h n c d -> b h (n c) d') + # combine inner and inter + o.add_(o2) + ctx.save_for_backward(q, k, v, g_original, A, initial_state) + ctx.CHECK = CHECK + return o.to(v), final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, d_final_state=None): + q, k, v, g_origin, A, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + # recomputation + # inter-chunk + BT = 16 # chunk_size + g = torch.empty_like(g_origin, dtype=torch.float32) + BK, BV = min(d_head_qk, 64), min(d_head_v, 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) + fwd_decay_cumsum[grid]( + g_origin, + g, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, BK=BK, DK=d_head_qk, num_warps=1 + ) + prepare_qg_kg[grid]( + q, k, g, q_g, k_g, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, BK=BK, DK=d_head_qk, num_warps=1 + ) + + # inter-chunk + BT = 16 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 2 + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_gla_bwd_kernel[grid]( + q_g, k_g, v, g, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + # clamp_min=-3, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + + # intra chunk + num_chunk = seq_len // BT + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk) + dA2 = (do2 @ v2.transpose(-2, -1)) * scale + dv2 = A.transpose(-1, -2) @ do2 + dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk) + + BK = min(triton.next_power_of_2(d_head_qk), 16) + NK = triton.cdiv(d_head_qk, BK) + dk2 = torch.empty_like(k) + dq2 = torch.empty_like(q) + + grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) + bwd_inner_chunk[grid]( + q, k, g, + dA2, dq2, dk2, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, BK=BK, + num_warps=1, + num_stages=3 + ) + + BK = min(triton.next_power_of_2(d_head_qk), 32) + NK = triton.cdiv(d_head_qk, BK) + dg = torch.empty_like(g, dtype=torch.float32) + grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) + bwd_decay_global_cumsum[grid]( + dq2, dq, dk2, dk, q, k, g, dg, + q.stride(1), q.stride(2), q.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, BK=BK, + num_warps=1, + num_stages=1 + ) + dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT) + + def rev_cumsum_exclusive(x): + cumsum_x = x.cumsum(-2) + rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x + return rev_cumsum_x + + rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :]) + dg.add_(rev_cumsum_dg.unsqueeze(-2)) + dv.add_(dv2) + dg = rearrange(dg, 'b h n c d -> b h (n c) d') + + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None + + +def pad(x, chunk_size=16): + seq_len = x.shape[-2] + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) + + return x + + +def ceildiv(a, b): + return -(a // -b) + + +def fused_chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = q.shape[-2] + q, k, v, g = map(lambda x: pad(x), [q, k, v, g]) + o, final_state = FusedChunkGLAFunction.apply( + q, k, v, g, scale, initial_state, output_final_state) + o = o[..., :seq_len, :] + return o, final_state diff --git a/finetune/lora/v6/fla/ops/gla/chunk_util.py b/finetune/lora/v6/fla/ops/gla/chunk_util.py new file mode 100644 index 0000000..ba9db38 --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/chunk_util.py @@ -0,0 +1,138 @@ +import triton +import triton.language as tl + +inv_ln2 = 1.44269504 + + + +@triton.jit +def fwd_decay_cumsum( + g, + g_o, + s_qk_h, + s_qk_t, + s_qk_d, + B, + H, + T, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + DK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + cum_decay = tl.zeros([BK], dtype=tl.float32) + mask = (i_k * BK + tl.arange(0, BK)) < DK + + for i in range(BT): + _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + cum_decay += _g * inv_ln2 + tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask) + p_g += DK + p_go += DK + +@triton.jit +def prepare_qg_kg( + q, + k, + g, + qg, + kg, + s_qk_h, + s_qk_t, + s_qk_d, + B, + H, + T, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + DK: tl.constexpr +): + + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) + + mask = (i_k * BK + tl.arange(0, BK)) < DK + + last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK)) + + for i in range(BT): + _q = tl.load(p_q, mask=mask, other=0) + _k = tl.load(p_k, mask=mask, other=0) + _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + _q *= tl.math.exp2(_g) * scale + _k *= tl.math.exp2(last_decay - _g) + tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask) + tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask) + p_q += DK + p_g += DK + p_k += DK + p_kg += DK + p_qg += DK + + +@triton.jit +def bwd_decay_global_cumsum( + dq_inner, + dq_inter, + dk_inner, + dk_inter, + q, k, g, dg, + s_qk_h, + s_qk_t, + s_qk_d, + B, + H, + T, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + DK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK + cum_grad_dg = tl.zeros([BK], dtype=tl.float32) + mask = (i_k * BK + tl.arange(0, BK)) < DK + last_g = tl.zeros([BK], dtype=tl.float32) + for j in range(BT-1, -1, -1): + _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + if j == (BT-1): + last_g = _g + _dq1 = tl.load(p_dq_inner, mask=mask, other=0) + _dq2 = tl.load(p_dq_inter, mask=mask, other=0) + _dq2 *= tl.math.exp2(_g) + _dq = _dq1 + _dq2 + tl.store(p_dq_inter, _dq, mask=mask) + _dk1 = tl.load(p_dk_inner, mask=mask, other=0) + _dk2 = tl.load(p_dk_inter, mask=mask, other=0) + _dk2 *= tl.math.exp2(last_g - _g) + _dk = _dk1 + _dk2 + tl.store(p_dk_inter, _dk, mask=mask) + _q = tl.load(p_q, mask=mask, other=0) + _k = tl.load(p_k, mask=mask, other=0) + _dg = _dq * _q - _dk * _k + cum_grad_dg += _dg + tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) + p_g -= DK + p_k -= DK + p_q -= DK + p_dq_inner -= DK + p_dk_inner -= DK + p_dq_inter -= DK + p_dk_inter -= DK + p_dg -= DK + diff --git a/finetune/lora/v6/fla/ops/gla/naive.py b/finetune/lora/v6/fla/ops/gla/naive.py new file mode 100644 index 0000000..b8cf03b --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/naive.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn.functional as F + +from fla.ops.gla.recurrent_fuse import fused_recurrent_gla + + +def ceildiv(a, b): + return -(a // -b) + + +def naive_recurrent_gla( + q, + k, + v, + gk, + initial_state=None, + output_final_state=False, + causal=True +): + orig_dtype = q.dtype + q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) + batch_size, n_heads, seq_len, d_head_k = q.shape + _, _, _, d_head_v = v.shape + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + scale = d_head_k ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(seq_len): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o_i = (q_i[..., None] * h).sum(-2) + o[:, :, i] = o_i + + if causal: + return o.to(orig_dtype), h + else: + o_reverse = torch.zeros_like(v) + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + for i in range(seq_len-1, -1, -1): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o_i = (q_i[..., None] * h).sum(-2) + o_reverse[:, :, i] = o_i + + return o, o_reverse + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 512 + D = 128 + dtype = torch.float32 + q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True) + g = F.logsigmoid(torch.rand(B, H, L, D)).cuda( + ).clamp_min(-1).to(torch.float32).requires_grad_(True) + + do = torch.rand_like(v).cuda() + do2 = torch.rand_like(v).cuda() + intial_state = torch.rand(B, H, D, D).cuda() + + ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False) + + ref.backward(do, retain_graph=True) + ref_rev.backward(do2, retain_graph=True) + + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + ref_dg, g.grad = g.grad.clone(), None + + tri, tri_rev = fused_recurrent_gla( + q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False) + tri.backward(do, retain_graph=True) + tri_rev.backward(do2, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None + + assert ref.allclose(tri, 0, 1e-5), breakpoint() + assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint() + assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() + assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() + assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() + assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() + + # tri = fused_chunk_gla(q, k, v, g) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + # tri_dg, g.grad = g.grad.clone(), None + + # assert ref.allclose(tri, 0, 1e-5), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() + # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() + # breakpoint() + print("Pass") diff --git a/finetune/lora/v6/fla/ops/gla/recurrent_fuse.py b/finetune/lora/v6/fla/ops/gla/recurrent_fuse.py new file mode 100644 index 0000000..ea14dca --- /dev/null +++ b/finetune/lora/v6/fla/ops/gla/recurrent_fuse.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_recurrent_gla_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + gk, # log gate [B, H, L, D_head_K] + gv, # log gate [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + final_state, # final hidden state [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction + USE_GK: tl.constexpr, # whether to use gk + USE_GV: tl.constexpr, # whether to use gv +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < DK + mask_bv = (i_v * BV + tl.arange(0, BV)) < DV + + h = tl.zeros([BV, BK], dtype=tl.float32) + + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * _gk[None, :] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * _gv[:, None] + h += _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + p_q += -DK if REVERSE else DK + p_k += -DK if REVERSE else DK + p_o += -DV if REVERSE else DV + p_v += -DV if REVERSE else DV + if USE_GK: + p_gk += -DK if REVERSE else DK + if USE_GV: + p_gv += -DV if REVERSE else DV + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_gla_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + gk, # log gate [B, H, L, D_head_K] \alpha + gv, # log gate [B, H, L, D_head_V] \bete + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction + USE_GK: tl.constexpr, # whether to use gk + USE_GV: tl.constexpr, # whether to use gv +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_do = do + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T-1) * DK if REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T-1) * DV if REVERSE else 0) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + mask_kv = mask_bk[:, None] & mask_bv[None, :] + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + h = h * _gk[:, None] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + h = h * _gv[None, :] + h += _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += -DK if REVERSE else DK + p_v += -DV if REVERSE else DV + p_q += -DK if REVERSE else DK + p_do += -DV if REVERSE else DV + p_dq += -DK if REVERSE else DK + if USE_GK: + p_gk += -DK if REVERSE else DK + if USE_GV: + p_gv += -DV if REVERSE else DV + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_k = k + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_do = do + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + p_v = v + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_bh * s_qk_h + i_k * BK + \ + tl.arange(0, BK) + ((T - 1) * DK if not REVERSE else 0) + if USE_GV: + p_gv = gv + i_bh * s_vo_h + i_v * BV + \ + tl.arange(0, BV) + ((T - 1) * DV if not REVERSE else 0) + + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + if USE_GK: + _gk = tl.load(p_gk, mask=mask_bk, other=0).to(tl.float32) + d_h *= _gk[:, None] + if USE_GV: + _gv = tl.load(p_gv, mask=mask_bv, other=0).to(tl.float32) + d_h *= _gv[None, :] + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do += DV if REVERSE else -DV + p_q += DK if REVERSE else -DK + p_k += DK if REVERSE else -DK + p_v += DV if REVERSE else -DV + p_dk += DK if REVERSE else -DK + p_dv += DV if REVERSE else -DV + if USE_GK: + p_gk += DK if REVERSE else -DK + if USE_GV: + p_gv += DV if REVERSE else -DV + + +class FusedRecurrentGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, gk, gv, scale=None, initial_state=None, output_final_state=False, reverse=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + # default scale + if scale is None: + scale = d_head_qk ** -0.5 + if gk is not None: + gk = gk.float().exp() + if gv is not None: + gv = gv.float().exp() + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + fused_recurrent_gla_fwd_kernel[grid]( + q, k, v, gk, gv, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, gk, gv, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + # we do not need the gradient of the final state from the next chunk + # similiar to Trunctated BPTT + if final_state is not None: + final_state = final_state.detach() + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, d_final_state=None): + q, k, v, gk, gv, initial_state, o = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=torch.float32) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=torch.float32) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=torch.float32) + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_gla_bwd_kernel[grid]( + q, k, v, gk, gv, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + if gk is not None: + _dgk = dq * q.float() - dk * k.float() + if ctx.reverse: + dgk = _dgk.cumsum(-2) + else: + _dgk_cumsum = _dgk.cumsum(-2) + dgk = _dgk + _dgk_cumsum[:, :, -1, None] - _dgk_cumsum + else: + dgk = None + + if gv is not None: + _dgv = do.float() * o.float() - dv * v.float() + if ctx.reverse: + dgv = _dgv.cumsum(-2) + else: + _dgv_cumsum = _dgv.cumsum(-2) + dgv = _dgv + _dgv_cumsum[:, :, -1, None] - _dgv_cumsum + else: + dgv = None + + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dgk, dgv, None, None, None, None + + +# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0 +def fused_recurrent_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor = None, + gv: torch.Tensor = None, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + causal: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if causal: + o, final_state = FusedRecurrentGLAFunction.apply(q, k, v, gk, gv, scale, initial_state, output_final_state) + return o, final_state + else: + # do not support initial_state yet. looks very strange for bidirectional modeling + assert initial_state is None + assert output_final_state is False + o, final_state = FusedRecurrentGLAFunction.apply( + q, k, v, gk, gv, scale, initial_state, output_final_state, False) + o_reversed, final_state = FusedRecurrentGLAFunction.apply( + q, k, v, gk, gv, scale, initial_state, output_final_state, True) + return [o, o_reversed] diff --git a/finetune/lora/v6/fla/ops/hgrn/__init__.py b/finetune/lora/v6/fla/ops/hgrn/__init__.py new file mode 100644 index 0000000..96f24b1 --- /dev/null +++ b/finetune/lora/v6/fla/ops/hgrn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_hgrn +from .recurrent_fuse import fused_recurrent_hgrn + +__all__ = [ + 'chunk_hgrn', + 'fused_recurrent_hgrn' +] diff --git a/finetune/lora/v6/fla/ops/hgrn/chunk.py b/finetune/lora/v6/fla/ops/hgrn/chunk.py new file mode 100644 index 0000000..6efb77c --- /dev/null +++ b/finetune/lora/v6/fla/ops/hgrn/chunk.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Yu Zhang, Songlin Yang + +# this function implements the chunkwise form of HGRN, inspired by +# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html) +# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan + +# from tests on H800, with B, H, D = 16, 4, 128, we see that the chunk can be greatly faster than the recurrent: +# +# Performance: +# seq_len chunk recurrent chunk_bwd recurrent_bwd +# 0 128.0 0.039360 0.061056 0.312160 0.205008 +# 1 256.0 0.045824 0.123712 0.308784 0.297696 +# 2 512.0 0.058688 0.241952 0.310720 0.626528 +# 3 1024.0 0.088288 0.476992 0.313184 1.333152 +# 4 2048.0 0.169472 0.943264 0.452464 2.724864 +# 5 4096.0 0.329920 1.886144 0.881600 5.551520 +# 6 8192.0 0.647872 3.755040 1.740496 11.117184 +# 7 16384.0 1.272064 7.520576 3.446608 22.362528 + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def chunk_hgrn_fwd_kernel_h( + x, + g, + gc, + o, + h0, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_bh * T * D + i_t * BT * D + o_d + p_g = g + i_bh * T * D + i_t * BT * D + o_d + p_gc = gc + i_bh * T * D + i_t * BT * D + o_d + p_o = o + i_bh * T * D + i_t * BT * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + b_gc = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + if i_t == 0: + b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32) + for i in range(0, BT): + mask_t = mask & ((i_t * BT + i) < T) + b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32) + b_h = tl.exp(b_g) * b_h + b_x + b_gc = b_gc + b_g + tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t) + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t) + + p_x += D + p_g += D + p_gc += D + p_o += D + + +@triton.jit +def chunk_hgrn_fwd_kernel_o( + gc, + o, + s_h, + s_t, + s_d, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(1, tl.cdiv(T, BT)): + p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32) + # [BT, BD] + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_o = b_o + tl.exp(b_gc) * b_h0[None, :] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def chunk_hgrn_bwd_kernel_h( + g, + gc, + dx, + do, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + b_dh = tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * tl.exp(b_g) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit +def chunk_hgrn_bwd_kernel_o( + g, + gc, + o, + dx, + dg, + s_h, + s_t, + s_d, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + mask_t = mask & ((i_t + 1) * BT < T) + b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + b_dg = tl.load(p_dg, boundary_check=(0, 1)).to(tl.float32) + b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * tl.exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkHGRNFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, H, T, D = x.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + o = torch.empty_like(x, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H) + chunk_hgrn_fwd_kernel_h[grid]( + x, g, gc, o, initial_state, + T, D, + BT=BT, + USE_INITIAL_STATE=initial_state is not None + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + chunk_hgrn_fwd_kernel_o[grid]( + gc, o, + o.stride(1), o.stride(2), o.stride(3), + T, D, + BT=BT, BD=BD, + num_warps=num_warps + ) + final_state = None + if output_final_state: + final_state = o[:, :, -1].clone() + o = o.to(x.dtype) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, H, T, D = do.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + dx = torch.empty_like(o) + dg = torch.empty_like(g) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H) + chunk_hgrn_bwd_kernel_h[grid]( + g, gc, dx, do, + T, D, + BT=BT + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + chunk_hgrn_bwd_kernel_o[grid]( + g, gc, o, dx, dg, + o.stride(1), o.stride(2), o.stride(3), + T, D, + BT=BT, BD=BD, + num_warps=num_warps + ) + if initial_state is not None: + dg[:, :, 0] = initial_state * dx[:, :, 0] * g[:, :, 0].exp() + + return dx, dg, None, None + + +def chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkHGRNFunction.apply(x, g, initial_state, output_final_state) + return o, final_state + + +if __name__ == '__main__': + import torch.nn.functional as F + + from fla.ops.hgrn.naive import naive_recurrent_hgrn + from fla.ops.hgrn.recurrent_fuse import fused_recurrent_hgrn + B, H, T, D = 8, 4, 512, 128 + dtype = torch.bfloat16 + torch.manual_seed(42) + # [batch_size, n_heads, seq_len, d_head] + x = torch.randn((B, H, T, D), dtype=dtype, device='cuda') + g = torch.randn((B, H, T, D), dtype=dtype, device='cuda') + x, g = (1 - g.sigmoid()) * x, F.logsigmoid(g) + print(f'x:\t{float(x.min()):>10.6f}\t{float(x.max()):>10.6f}') + print(f'g:\t{float(g.min()):>10.6f}\t{float(g.max()):>10.6f}') + x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g)) + print(f"DTYPE:\t{x.dtype}") + do = torch.randn_like(x) + h0 = torch.randn_like(x[:, :, 0]) + ref, ref_ht = naive_recurrent_hgrn(x, g, h0, output_final_state=True) + ref.backward(do) + ref_dx, x.grad = x.grad.clone(), None + ref_dg, g.grad = g.grad.clone(), None + + tri, tri_ht = fused_recurrent_hgrn(x, g, h0, output_final_state=True) + tri.backward(do) + tri_dx, x.grad = x.grad.clone(), None + tri_dg, g.grad = g.grad.clone(), None + print(" \t DIFF\t MAX") + print(' o\t', f"{float((ref - tri).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}") + print('ht\t', f"{float((ref_ht[0] - tri_ht[0]).abs().max()):>10.6f}\t{float(ref.max()):>10.6f}") + print('dx\t', f"{float((ref_dx - tri_dx).abs().max()):>10.6f}\t{float(ref_dx.max()):>10.6f}") + print('dg\t', f"{float((ref_dg - tri_dg).abs().max()):>10.6f}\t{float(ref_dg.max()):>10.6f}") + print('Done!') + + @triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['seq_len'], + # different possible values for `x_name` + x_vals=[128 * 2 ** i for i in range(0, 8)], + # argument name whose value corresponds to a different line in the plot + line_arg='provider', + # possible values for `line_arg`` + line_vals=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], + # label name for the lines + line_names=['chunk', 'recurrent', 'chunk_bwd', 'recurrent_bwd'], + # line styles + styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')], + ylabel="Execution Time (ms)", # label name for the y-axis + # name for the plot. Used also as a file name for saving the plot. + plot_name="Performance", + args={}, + ) + ) + def benchmark(seq_len, provider): + dtype = torch.bfloat16 + B, H, D = 16, 4, 128 + + x = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda') + g = torch.randn((B, H, seq_len, D), dtype=dtype, device='cuda').sigmoid() + x = (1 - g) * x + x, g = (i.detach().clone().to(dtype).requires_grad_() for i in (x, g)) + do = torch.randn_like(x, dtype=dtype) + quantiles = [0.5, 0.2, 0.8] + results = 0, 0, 0 + if provider == 'chunk': + results = triton.testing.do_bench(lambda: chunk_hgrn(x, g), quantiles=quantiles) + if provider == 'recurrent': + results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g), quantiles=quantiles) + if provider == 'chunk_bwd': + results = triton.testing.do_bench(lambda: chunk_hgrn(x, g)[0].backward(do), quantiles=quantiles) + if provider == 'recurrent_bwd': + results = triton.testing.do_bench(lambda: fused_recurrent_hgrn(x, g)[0].backward(do), quantiles=quantiles) + return results + benchmark.run(print_data=True) diff --git a/finetune/lora/v6/fla/ops/hgrn/naive.py b/finetune/lora/v6/fla/ops/hgrn/naive.py new file mode 100644 index 0000000..0d79cb9 --- /dev/null +++ b/finetune/lora/v6/fla/ops/hgrn/naive.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, H, T, D = x.shape + + h = torch.zeros(B, H, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state.detach() + + for i in range(T): + h = g[:, :, i].exp() * h + x[:, :, i] + o[:, :, i] = h + + if output_final_state: + final_state = h + return o.to(dtype), final_state diff --git a/finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py b/finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py new file mode 100644 index 0000000..82224d6 --- /dev/null +++ b/finetune/lora/v6/fla/ops/hgrn/recurrent_fuse.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def fused_recurrent_hgrn_fwd_kernel( + x, + g, + o, + h0, + ht, + T: tl.constexpr, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_bh * T * D + o_d + p_g = g + i_bh * T * D + o_d + p_o = o + i_bh * T * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * D + o_d + b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) + for _ in range(0, T): + b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_h = tl.exp(b_g) * b_h + b_x + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) + + p_x += D + p_g += D + p_o += D + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * D + o_d + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def fused_recurrent_hgrn_bwd_kernel( + g, + o, + dx, + dg, + do, + h0, + T: tl.constexpr, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_g = g + (i_bh * T + T - 1) * D + o_d + p_o = o + (i_bh * T + T - 2) * D + o_d + p_dx = dx + (i_bh * T + T - 1) * D + o_d + p_dg = dg + (i_bh * T + T - 1) * D + o_d + p_do = do + (i_bh * T + T - 1) * D + o_d + + b_dh = tl.zeros([BD], dtype=tl.float32) + for i in range(T - 1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + if i > 0: + b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) + elif USE_INITIAL_STATE: + b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32) + else: + b_o = tl.zeros([BD], dtype=tl.float32) + + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * tl.exp(b_g) + b_dg = b_dh * b_o + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) + + p_g -= D + p_o -= D + p_dx -= D + p_dg -= D + p_do -= D + + +class FusedRecurrentHGRNFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, H, T, D = x.shape + + final_state = None + if output_final_state: + final_state = x.new_empty(B, H, D) + + o = torch.empty_like(x) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + fused_recurrent_hgrn_fwd_kernel[grid]( + x, g, o, initial_state, final_state, + T, D, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, H, T, D = do.shape + + dx = torch.empty_like(o) + dg = torch.empty_like(g) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + fused_recurrent_hgrn_bwd_kernel[grid]( + g, o, dx, dg, do, initial_state, + T, D, + USE_INITIAL_STATE=initial_state is not None, + ) + + return dx, dg, None, None + + +def fused_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/linear_attn/__init__.py b/finetune/lora/v6/fla/ops/linear_attn/__init__.py new file mode 100644 index 0000000..4563a5f --- /dev/null +++ b/finetune/lora/v6/fla/ops/linear_attn/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_linear_attn +from .chunk_fuse import fused_chunk_linear_attn +from .recurrent_fuse import fused_recurrent_linear_attn + +__all__ = [ + 'chunk_linear_attn', + 'fused_chunk_linear_attn', + 'fused_recurrent_linear_attn' +] + diff --git a/finetune/lora/v6/fla/ops/linear_attn/chunk.py b/finetune/lora/v6/fla/ops/linear_attn/chunk.py new file mode 100644 index 0000000..c474cd1 --- /dev/null +++ b/finetune/lora/v6/fla/ops/linear_attn/chunk.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.transpose(-2, -1) + k = k.cumsum(-1) + k = k.transpose(-2, -1) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-5) + + +@triton.jit +def chunk_linear_attn_fwd_kernel_h( + k, + v, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_linear_attn_bwd_kernel_dh( + q, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + + +@triton.jit +def chunk_linear_attn_bwd_kernel_dqkv( + q, + k, + v, + h, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + ctx.scale = scale + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False) + + h = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_linear_attn_fwd_kernel_h[grid]( + k, v, h, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NV, NT, B * H) + o = torch.empty_like(v) + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v, h, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + ctx.save_for_backward(q, k, v, h) + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_ht=None): + q, k, v, h = ctx.saved_tensors + + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = ctx.scale + + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_linear_attn_bwd_kernel_dh[grid]( + q, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + chunk_linear_attn_bwd_kernel_dqkv[grid]( + q, k, v, h, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + + if normalize: + o = normalize_output(q * scale, k, o) + + return o, final_state diff --git a/finetune/lora/v6/fla/ops/linear_attn/chunk_fuse.py b/finetune/lora/v6/fla/ops/linear_attn/chunk_fuse.py new file mode 100644 index 0000000..0ca7101 --- /dev/null +++ b/finetune/lora/v6/fla/ops/linear_attn/chunk_fuse.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from packaging import version +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.transpose(-2, -1) + k = k.cumsum(-1) + k = k.transpose(-2, -1) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-5) + + +@triton.jit +def fused_chunk_linear_attn_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_linear_attn_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # b_dd = (b_do]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + tl.store(p_dk, (b_dk * scale).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkLinearAttentionFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + ctx.scale = scale + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = ctx.scale + + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + if scale == -1: + scale = q.shape[-1] ** -0.5 + o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/linear_attn/naive.py b/finetune/lora/v6/fla/ops/linear_attn/naive.py new file mode 100644 index 0000000..60b3197 --- /dev/null +++ b/finetune/lora/v6/fla/ops/linear_attn/naive.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def torch_chunk_linear_attn(q, k, v, chunk_size=64): + q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5) + k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + kv = torch.cat([ + torch.zeros_like(kv[:, :, :1]), + kv[:, :, :-1] + ], dim=2) + inter = q @ kv + intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v + o = inter + intra + return rearrange(o, 'b h n c d -> b h (n c) d') diff --git a/finetune/lora/v6/fla/ops/linear_attn/recurrent_fuse.py b/finetune/lora/v6/fla/ops/linear_attn/recurrent_fuse.py new file mode 100644 index 0000000..20bd0fe --- /dev/null +++ b/finetune/lora/v6/fla/ops/linear_attn/recurrent_fuse.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.transpose(-2, -1) + k = k.cumsum(-1) + k = k.transpose(-2, -1) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-5) + + +@triton.jit +def fused_recurrent_linear_attn_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + initial_state, + final_state, # final hidden state [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < DK + mask_bv = (i_v * BV + tl.arange(0, BV)) < DV + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + h += _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += DK + p_k += DK + p_o += DV + p_v += DV + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_linear_attn_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + h += _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += DK + p_do += DV + p_v += DV + p_dq += DK + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + (T - 1) * DK + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + (T - 1) * DV + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= DV + p_q -= DK + p_k -= DK + p_v -= DV + p_dk -= DK + p_dv -= DV + + +class FusedRecurrentLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, initial_state=None, output_final_state=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + fused_recurrent_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None + + +def fused_recurrent_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedRecurrentLinearAttentionFunction.apply( + q, k, v, initial_state, output_final_state) + if normalize: + o = normalize_output(q, k, o) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/rebased/__init__.py b/finetune/lora/v6/fla/ops/rebased/__init__.py new file mode 100644 index 0000000..6ec6a0c --- /dev/null +++ b/finetune/lora/v6/fla/ops/rebased/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_rebased + +__all__ = [ + 'parallel_rebased' +] diff --git a/finetune/lora/v6/fla/ops/rebased/naive.py b/finetune/lora/v6/fla/ops/rebased/naive.py new file mode 100644 index 0000000..77bdf56 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rebased/naive.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +from fla.ops.rebased.parallel import parallel_rebased + +def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True): + if use_scale: + q = q * (q.shape[-1] ** -0.5) + attn = q @ k.transpose(-2, -1) + attn = (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 128 + # D = 15 + dtype = torch.float32 + q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) + + do = torch.randn_like(v).cuda() + ref = naive_parallel_rebased(q, k, v, True, True) + ref.backward(do, retain_graph=True) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + # tri = naive_chunk_based(q, k, v) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + tri = parallel_rebased(q, k, v, 1e-6, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() + + # tri = parallel_based(q, k, v, True, True) + # tri.backward(do, retain_graph=True) + # tri_dq, q.grad = q.grad.clone(), None + # tri_dk, k.grad = k.grad.clone(), None + # tri_dv, v.grad = v.grad.clone(), None + + # print((ref-tri).abs().max()) + # print((ref_dq-tri_dq).abs().max()) + # print((ref_dk-tri_dk).abs().max()) + # print((ref_dv-tri_dv).abs().max()) + + # assert ref.allclose(tri, 0, 1e-4), breakpoint() + # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() + # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() + # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() diff --git a/finetune/lora/v6/fla/ops/rebased/parallel.py b/finetune/lora/v6/fla/ops/rebased/parallel.py new file mode 100644 index 0000000..73920b5 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rebased/parallel.py @@ -0,0 +1,387 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models +# https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py + + +@triton.jit +def parallel_rebased_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_rebased_bwd_kernel( + q, k, v, do, dz, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, device=q.device) + z = torch.empty(NK, batch_size, n_heads, seq_len, + device=q.device) + parallel_rebased_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_rebased_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + return o.to(q.dtype) diff --git a/finetune/lora/v6/fla/ops/retention/__init__.py b/finetune/lora/v6/fla/ops/retention/__init__.py new file mode 100644 index 0000000..b7f29d7 --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_retention +from .chunk_fuse import fused_chunk_retention +from .parallel import parallel_retention +from .recurrent_fuse import fused_recurrent_retention + +__all__ = [ + 'chunk_retention', + 'fused_chunk_retention', + 'parallel_retention', + 'fused_recurrent_retention' +] diff --git a/finetune/lora/v6/fla/ops/retention/chunk.py b/finetune/lora/v6/fla/ops/retention/chunk.py new file mode 100644 index 0000000..0b162be --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/chunk.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + + +@triton.jit +def chunk_retention_fwd_kernel_h( + k, + v, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + if i_t == NT - 1 and (T % BT) != 0: + d_b = tl.math.exp2((T % BT) * b_b) + d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b) + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_retention_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_i = tl.math.exp2((o_i + 1) * b_b) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s *= d_s + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_retention_bwd_kernel_dh( + q, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False) + + +@triton.jit +def chunk_retention_bwd_kernel_dqkv( + q, + k, + v, + h, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + n_bh = tl.num_programs(2) + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + d_q = (d_q * scale).to(d_q.dtype) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_ds = (b_ds * d_s).to(b_q.dtype) + # [BT, BK] + b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False) + b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, q, k, v, initial_state, output_final_state): + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = K ** -0.5 + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False) + + h = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_retention_fwd_kernel_h[grid]( + k, v, h, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NV, NT, B * H) + o = torch.empty_like(v) + chunk_retention_fwd_kernel_o[grid]( + q, k, v, h, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + ctx.save_for_backward(q, k, v, h) + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_ht=None): + q, k, v, h = ctx.saved_tensors + + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = K ** -0.5 + + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_retention_bwd_kernel_dh[grid]( + q, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + chunk_retention_bwd_kernel_dqkv[grid]( + q, k, v, h, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None + + +def chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/retention/chunk_fuse.py b/finetune/lora/v6/fla/ops/retention/chunk_fuse.py new file mode 100644 index 0000000..80af7f2 --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/chunk_fuse.py @@ -0,0 +1,334 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from packaging import version +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_chunk_retention_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + # d_b: overall decay for the entire chunk + # d_o: cumulative decay from the start of the chunk + # d_h: cumulative decay from the end of the chunk + d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + NT = tl.cdiv(T, BT) + for i in range(0, NT): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + if i == NT - 1 and (T % BT) != 0: + d_b = tl.math.exp2((T % BT) * b_b) + d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b) + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_retention_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b) + d_b = tl.math.exp2(BT * b_b) + + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + # [BT, DK] + b_dq = tl.dot(b_ds, b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + d_s = tl.trans(d_s) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, batch_size * n_heads) + fused_chunk_retention_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + + BT = 64 + BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_chunk_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None + + +def fused_chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/retention/naive.py b/finetune/lora/v6/fla/ops/retention/naive.py new file mode 100644 index 0000000..15611bf --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/naive.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import torch + + +def naive_retention(q, k, v): + orig_type = q.dtype + q, k, v = q.float(), k.float(), v.float() + _, n_heads, seq_len, d_head = q.shape + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2() + n = q.new_tensor(range(seq_len), dtype=torch.float) + n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) + s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) + o = torch.einsum('bhqk,bhkd->bhqd', s, v) + return o.to(orig_type) diff --git a/finetune/lora/v6/fla/ops/retention/parallel.py b/finetune/lora/v6/fla/ops/retention/parallel.py new file mode 100644 index 0000000..18c7268 --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/parallel.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + + +@triton.jit +def parallel_retention_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # cumulative decay from the end of the chunk + o_k = tl.arange(0, BTS) + d_h = tl.math.exp2((BTS - o_k) * b_b) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_o = b_o * tl.math.exp2(b_b * BTS) + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + d_q = tl.math.exp2(tl.arange(0, BTL) * b_b) + b_o *= d_q[:, None] + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), + (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # cumulative decay from the end of the chunk + d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b) + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_dq *= d_b + b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale + # [BTL, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + # no overlap. no need for mask. + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b) + b_kd = (b_k * d_h[:, None]).to(b_k.dtype) + d_q = tl.math.exp2(tl.arange(0, BTS) * b_b) + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS] + b_do = (b_do * d_q[None, :]).to(b_do.dtype) + + b_dv *= d_b + b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS] + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + + b_dk *= d_b + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dk *= d_h[:, None] * scale + b_dv *= scale + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s + # [BK, BD] + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + o_q += BTS + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_retention_bwd_kernel( + q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, + BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + DK: tl.constexpr, DV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(DV, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV + ) + tl.debug_barrier() + _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV + ) + + +class ParallelRetentionFunction(torch.autograd.Function): + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, q, k, v): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 3 if d_head_qk <= 64 else 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + scale = d_head_qk ** -0.5 + o = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + parallel_retention_fwd_kernel[grid]( + q, k, v, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + return o.sum(0).to(q.dtype) + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do): + q, k, v = ctx.saved_tensors + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + num_stages = 3 if d_head_qk <= 64 else 2 + num_warps = 4 + NK = triton.cdiv(d_head_qk, BK) + NV = triton.cdiv(d_head_v, BV) + grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads) + scale = d_head_qk ** -0.5 + + dq = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, batch_size, n_heads, seq_len, + d_head_qk, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, batch_size, n_heads, seq_len, + d_head_v, dtype=q.dtype, device=q.device) + + parallel_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype) + + +parallel_retention = ParallelRetentionFunction.apply diff --git a/finetune/lora/v6/fla/ops/retention/recurrent_fuse.py b/finetune/lora/v6/fla/ops/retention/recurrent_fuse.py new file mode 100644 index 0000000..f78b45f --- /dev/null +++ b/finetune/lora/v6/fla/ops/retention/recurrent_fuse.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_recurrent_retention_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + initial_state, + final_state, # final hidden state [B, H, D_head_K, D_head_V] + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + # decay rate given the head index + b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0)) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < DK + mask_bv = (i_v * BV + tl.arange(0, BV)) < DV + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + h = b_b * h + _k[None, :] * _v[:, None] + _o = h * _q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += DK + p_k += DK + p_o += DV + p_v += DV + + if STORE_FINAL_STATE: + p_final_s = final_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[None, :]) * \ + DV + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_retention_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + h = b_b * h + _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += DK + p_do += DV + p_v += DV + p_dq += DK + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + (T - 1) * DK + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + (T - 1) * DV + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + + d_h *= b_b + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= DV + p_q -= DK + p_k -= DK + p_v -= DV + p_dk -= DK + p_dv -= DV + + +class FusedRecurrentRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, initial_state=None, output_final_state=False): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + + scale = d_head_qk ** -0.5 + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v) + else: + final_state = None + + grid = (NV, NK, batch_size * n_heads) + fused_recurrent_retention_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, d_final_state=None): + q, k, v, initial_state = ctx.saved_tensors + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + + BK, BV = min(d_head_qk, 32), min(d_head_v, 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + num_stages = 1 + num_warps = 1 + + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + + fused_recurrent_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None + + +# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply + +def fused_recurrent_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/rotary.py b/finetune/lora/v6/fla/ops/rotary.py new file mode 100644 index 0000000..18ccc5f --- /dev/null +++ b/finetune/lora/v6/fla/ops/rotary.py @@ -0,0 +1,252 @@ +# Copyright (c) 2023, Tri Dao. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py + +from typing import Optional, Union + +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 2}), +# triton.Config({"BLOCK_M": 4}), +# triton.Config({"BLOCK_M": 8}), +# triton.Config({"BLOCK_M": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + BLOCK_K: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * BLOCK_M >= seqlen: + return + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, BLOCK_K) + rk_half = tl.arange(0, BLOCK_K // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rm[:, None] * stride_x_seqlen + + rk_half[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load( + COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0 + ).to(tl.float32) + sin = tl.load( + SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x0 = tl.load( + X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0 + ).to(tl.float32) + x1 = tl.load( + X + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + OUT = OUT + (rm[:, None] * stride_out_seqlen + + rk_half[None, :] * stride_out_headdim) + tl.store(OUT, o0, mask=(rm[:, None] < seqlen) + & (rk_half[None, :] < rotary_dim_half)) + tl.store( + OUT + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, BLOCK_K) // 2 + X0 = X + (rm[:, None] * stride_x_seqlen + + rk[None, :] * stride_x_headdim) + X1 = X + (rm[:, None] * stride_x_seqlen + + rk_swap[None, :] * stride_x_headdim) + COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + COS, + mask=(rm_cs[:, None] < seqlen_ro) & ( + rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + SIN, + mask=(rm_cs[:, None] < seqlen_ro) & ( + rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to( + tl.float32 + ) + x1 = tl.load( + X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0 + ).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + OUT = OUT + (rm[:, None] * stride_out_seqlen + + rk[None, :] * stride_out_headdim) + tl.store(OUT, out, mask=(rm[:, None] < seqlen) + & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert ( + cos.dtype == sin.dtype + ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert ( + x.dtype == cos.dtype + ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + BLOCK_K = ( + 32 + if rotary_dim <= 32 + else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + ) + def grid(META): return (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa + BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + # key for triton cache (limit number of compilations) + seqlen // 128, + # batch_strides if not varlen else 0 + output.stride(0) if not is_varlen else 0, + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + # batch_strides if not varlen else 0 + x.stride(0) if not is_varlen else 0, + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + BLOCK_K, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M, + ) + return output diff --git a/finetune/lora/v6/fla/ops/rwkv4/__init__.py b/finetune/lora/v6/fla/ops/rwkv4/__init__.py new file mode 100644 index 0000000..ae23a00 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv4/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .recurrent_fuse import fused_recurrent_rwkv4 + +__all__ = [ + 'fused_recurrent_rwkv4' +] diff --git a/finetune/lora/v6/fla/ops/rwkv4/recurrent_fuse.py b/finetune/lora/v6/fla/ops/rwkv4/recurrent_fuse.py new file mode 100644 index 0000000..3232087 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv4/recurrent_fuse.py @@ -0,0 +1,484 @@ +# -*- coding: utf-8 -*- +# adopted from https://github.com/codekansas/rwkv + +from typing import Any, cast + +import torch +import triton +import triton.language as tl +from torch import Tensor +from torch.autograd.function import Function, FunctionCtx, once_differentiable + + +def get_block_size_c(chans: int) -> int: + if chans < 32: + return 32 + if chans < 64: + return 64 + return 128 + + +@triton.jit +def fused_recurrent_rwkv4_forward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_c, + # WKV + wkv_ptr, + wkv_s_b, + wkv_s_t, + wkv_s_c, + # Output state + state_out_ptr, + state_out_s_b, + state_out_s_abe, + state_out_s_t, + state_out_s_c, + # Params + chans, + tsz, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + wkv_ptr = wkv_ptr + b_idx * wkv_s_b + alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b + beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe + eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe + + # Loads parameters. + alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32) + + for t in range(tsz): + kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps) + e1a = tl.exp(eps - tau) + e2a = tl.exp(ukt - tau) + wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a) + tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask) + + w_eps = w + eps + eps = tl.maximum(w_eps, kt) + e1b = tl.exp(w_eps - eps) + e2b = tl.exp(kt - eps) + alpha = e1b * alpha + e2b * vt + beta = e1b * beta + e2b + tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask) + tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask) + tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask) + + +def fused_recurrent_rwkv4_forward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, +) -> tuple[Tensor, Tensor]: + (bsz, tsz, chans) = k.shape + + # New tensors to output. + wkvs = k.new_empty(bsz, tsz, chans) + state_out = k.new_empty(bsz, 3, tsz, chans) + + # Constants. + block_size_c = get_block_size_c(chans) + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_forward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(3), + # WKV + wkvs, + wkvs.stride(0), + wkvs.stride(1), + wkvs.stride(2), + # Output state + state_out, + state_out.stride(0), + state_out.stride(1), + state_out.stride(2), + state_out.stride(3), + # Params + chans, + tsz, + BLOCK_SIZE_C=block_size_c, + ) + + state_out = torch.cat((state, state_out), dim=2) + + return wkvs, state_out + + +@triton.jit +def fused_recurrent_rwkv4_backward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_t, + state_s_c, + # WKV grad + gwkv_ptr, + gwkv_s_b, + gwkv_s_t, + gwkv_s_c, + # Output state grad + gstate_out_ptr, + gstate_out_s_b, + gstate_out_s_abe, + gstate_out_s_c, + # W grad + gw_ptr, + gw_s_c, + # U grad + gu_ptr, + gu_s_c, + # K grad + gk_ptr, + gk_s_b, + gk_s_t, + gk_s_c, + # V grad + gv_ptr, + gv_s_b, + gv_s_t, + gv_s_c, + # State grad + gstate_ptr, + gstate_s_b, + gstate_s_abe, + gstate_s_c, + # Params + tsz, + chans, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + gk_ptr = gk_ptr + b_idx * gk_s_b + gv_ptr = gv_ptr + b_idx * gv_s_b + + # Pointers to gradients which were recieved by the function. + gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b + galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe + geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe + + # Loads parameters. + galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32) + + # Gradient accumulators. + gw = tl.zeros_like(w) + gu = tl.zeros_like(u) + + alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + for t in range(tsz): + tc = tsz - t - 1 + + kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32) + + alpha_curr = alpha_prev + beta_curr = beta_prev + eps_curr = eps_prev + + alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps_prev) + e1 = tl.exp(eps_prev - tau) + e2 = tl.exp(ukt - tau) + + euke = tl.exp(ukt + eps_prev - 2 * tau) + + denom = e1 * beta_prev + e2 + denom_sq = denom * denom + + gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32) + + # Backpropagates wkv gradients. + guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq + gu += guk + gk = guk + gv = gwkvt * e2 / denom + + galpha_wkv = gwkvt * e1 / denom + gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq + geps_wkv_denom = e1 * beta_prev + e2 + geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom) + + e1 = tl.exp(w + eps_prev - eps_curr) + e2 = tl.exp(kt - eps_curr) + + # Backpropagates alpha gradients. + galpha_we = galpha * e1 * alpha_prev + gw += galpha_we + gk += galpha * e2 * vt + gv += galpha * e2 + geps += galpha * -alpha_curr + + # Backpropagates beta gradients. + gbeta_we = gbeta * e1 * beta_prev + gw += gbeta_we + gk += gbeta * e2 + geps += gbeta * -beta_curr + + # Backpropagates epsilon gradients. + geps_mask = w + eps_prev > kt + geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps)) + gw += geps_we + gk += tl.where(geps_mask, tl.zeros_like(geps), geps) + + # Stores the gradients for k and v. + tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask) + tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask) + + # Computes new gradients for alpha and beta. + galpha = galpha * e1 + galpha_wkv + gbeta = gbeta * e1 + gbeta_wkv + geps = galpha_we + gbeta_we + geps_we + geps_wkv + + # Stores final gradients for alpha and beta. + galpha_ptr = gstate_ptr + b_idx * gstate_s_b + gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe + geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe + tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask) + tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask) + tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask) + + # Stores final gradients for w and u. + gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32) + gw_temp += gw + tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask) + gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32) + gu_temp += gu + tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask) + + +def fused_recurrent_rwkv4_backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + grad_wkv: Tensor, + grad_state: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + bsz, tsz, chans = k.shape + + gw = torch.zeros_like(w) # New tensors to output. + gu = torch.zeros_like(u) + gk = torch.empty_like(k) + gv = torch.empty_like(v) + gstate = k.new_empty(bsz, 3, 1, chans) + + block_size_c = get_block_size_c(chans) # Constants. + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_backward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # WKV grad + grad_wkv, + grad_wkv.stride(0), + grad_wkv.stride(1), + grad_wkv.stride(2), + # Output state grad + grad_state, + grad_state.stride(0), + grad_state.stride(1), + grad_state.stride(3), + # W grad + gw, + gw.stride(0), + # U grad + gu, + gu.stride(0), + # K grad + gk, + gk.stride(0), + gk.stride(1), + gk.stride(2), + # V grad + gv, + gv.stride(0), + gv.stride(1), + gv.stride(2), + # State grad + gstate, + gstate.stride(0), + gstate.stride(1), + gstate.stride(3), + # Params + tsz, + chans, + BLOCK_SIZE_C=block_size_c, + ) + + return gw, gu, gk, gv, gstate + + +class FusedRecurrentRWKV4Function(Function): + @staticmethod + def forward( + ctx: FunctionCtx, + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + ) -> tuple[Tensor, Tensor]: + ctx.input_dtype = k.dtype + + if ( + w.device.type != "cuda" + or u.device.type != "cuda" + or k.device.type != "cuda" + or v.device.type != "cuda" + ): + raise ValueError( + "Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices." + ) + + w = -torch.exp(w.float().contiguous()) + if k.dtype == torch.float16: + u = u.float() + k = k.float() + v = v.float() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state) + ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1]) + return wkv, state_out[:, :, -1:] + + @staticmethod + @once_differentiable + def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors) + gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate) + return gw, gu, gk, gv, gstate + + +def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]: + return FusedRecurrentRWKV4Function.apply(w, u, k, v, state) diff --git a/finetune/lora/v6/fla/ops/rwkv6/__init__.py b/finetune/lora/v6/fla/ops/rwkv6/__init__.py new file mode 100644 index 0000000..52f9fe7 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv6/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_rwkv6 +from .recurrent_fuse import fused_recurrent_rwkv6 + +__all__ = [ + 'chunk_rwkv6', + 'fused_recurrent_rwkv6' +] diff --git a/finetune/lora/v6/fla/ops/rwkv6/chunk.py b/finetune/lora/v6/fla/ops/rwkv6/chunk.py new file mode 100644 index 0000000..0e746c9 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv6/chunk.py @@ -0,0 +1,921 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_reversed_cumsum_fwd +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BS': 16}, num_warps=2), + triton.Config({'BS': 16}, num_warps=4), + triton.Config({'BS': 16}, num_warps=8), + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_rwkv6_fwd_kernel_cum( + s, + o, + o_minus_s, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def post_process_grad( + q, + k, + v, + u, + do, + dk, + dq, + du, + scale, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + H, + T: tl.constexpr, + BT: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_h = i_bh % H + + # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V) + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_u = tl.load(p_u, boundary_check=(0,)) + + b_vdo = tl.sum(b_v * b_do, axis=1) + b_du = b_vdo[:, None] * b_k * b_q * scale + b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale + b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale + + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_h( + k, + v, + g, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + if i_t < NT - 1: + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + else: + b_gn = tl.min(b_g, axis=1) + b_h *= tl.exp(b_gn)[:, None] + b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype) + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_intra( + q, + k, + g, + gs, + u, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + H, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + DK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + i_h = i_bh % H + n_bh = tl.num_programs(2) + + o_k = i_k * BK + tl.arange(0, BK) + o_q = i_t * BT + i_i * BC + m_k = o_k < K + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BK,] + b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_q_self = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + o_i = tl.arange(0, BC) + o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + p_u = tl.make_block_ptr(u + i_h * DK, (DK,), (1,), (i_k * BK), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1) + b_A = tl.where(o_i > j, b_A, 0.) + # self + b_q_self = tl.load(p_q_self, boundary_check=(0,)).to(tl.float32) + A_self = tl.sum(b_q_self * b_k * b_u * scale, axis=0) + m_self = tl.arange(0, BC) == j + b_A = tl.where(m_self, A_self[None], b_A) + tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A) + p_k = tl.advance(p_k, (K,)) + p_q_self = tl.advance(p_q_self, (K,)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_inter( + q, + v, + gs, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_dh( + q, + g, + gs, + do, + dh, + dh0, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype) + + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_inter( + k, + v, + h, + g, + gs, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gq = tl.load(p_gq, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + + b_dq = b_dq * tl.exp(b_gq) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] > o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_intra( + q, + k, + g, + gs, + dA, + dq, + dk, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_k = i_k * BK + tl.arange(0, BK) + o_q = i_t * BT + i_i * BC + m_k = o_k < K + + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BK,] + b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0) + # [BC, BK] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg, allow_tf32=False) + b_dq *= tl.exp(b_gs - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0) + # [BC, BK] + m_i = o_i[:, None] > j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.) + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] < j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) + + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkRWKV6Function(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level): + q = r # alias + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_rwkv6_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float) + + g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g_org = g_org.view(B, H, NT, BT, -1) + # g = g_org.cumsum(-2).view(B, H, T, -1) + # gs = g - g_org + chunk_rwkv6_fwd_kernel_cum[grid]( + g_org, g, gs, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=final_state if final_state is not None else None + ) + A = q.new_zeros(NK, B, H, T, BT) + grid = (NK, NT * NC * NC, B * H) + chunk_rwkv6_fwd_kernel_intra[grid]( + q, k, g, gs, u, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + o = torch.empty_like(v) + + grid = (NV, NT, B * H) + chunk_rwkv6_fwd_kernel_inter[grid]( + q, v, gs, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + if checkpoint_level > 1: + del h + h, initial_state = None, None + del g, gs + ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A) + ctx.BT = BT + ctx.scale = scale + ctx.checkpoint_level = checkpoint_level + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + q, k, v, g, u, h, initial_state, A = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_rwkv6_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + dh0 = torch.empty_like(h0) if h0 is not None else None + grid = (NK, NV, B * H) + chunk_rwkv6_bwd_kernel_dh[grid]( + q, g, gs, do, dh, dh0, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return dh, dh0 + + # recompute cumulative log decays. + g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_rwkv6_fwd_kernel_cum[grid]( + g_org, g, gs, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + + # rerun the forward pass to get h if checkpoint_level >= 1 + if ctx.checkpoint_level == 1: + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=None + ) + + scale = ctx.scale + dh, dh0 = bwd_inner( + q, g, gs, initial_state, do, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + scale=scale + ) + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dv = v.new_empty(NK, *v.shape) + dA = q.new_zeros(B, H, T, BT) + grid = (NK, NT, B * H) + chunk_rwkv6_bwd_kernel_inter[grid]( + k, v, h, g, gs, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NK, NT * NC, B * H) + chunk_rwkv6_bwd_kernel_intra[grid]( + q, k, g, gs, dA, dq, dk, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + + # TODO: fuse? + dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1] + dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0) + dg = chunk_reversed_cumsum_fwd(dg).to(g) + # equivalent to the following pytorch code. + # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u) + # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :]) + # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :]) + BT = 64 + grid = (triton.cdiv(T, BT), B * H) + du = torch.empty_like(g, dtype=torch.float) + post_process_grad[grid]( + q, k, v, u, do, dk, dq, du, scale, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), H=H, + T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V), + num_warps=4 + ) + du = du.sum([0, 2]) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None + + +def chunk_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + checkpoint_level: Optional[int] = 0 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + reception of shape `(B, H, T, K)`. Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g. + u (torch.Tensor): + bonus of shape `(H, K)` + scale (Optional[int]): + Scale factor for the RWKV6 attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher values will save more memories and do more recomputations during backward. + Default: `0`: + - Level `0`: store forward hidden states for backprop. + - Level `1`: recompute the forward hidden states during backward. + """ + assert checkpoint_level in [0, 1] + if scale is None: + scale = r.shape[-1] ** -0.5 + o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level) + return o, final_state + + +if __name__ == "__main__": + import torch.nn.functional as F + + from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6 + B = 4 + H = 4 + L = 1024 + K = 100 + V = 120 + + torch.manual_seed(0) + dtype = torch.float32 + q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True) + k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True) + v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True) + w = (-torch.randn(B, H, L, K).exp()).cuda().to(torch.float32).requires_grad_(True) + u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True) + h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True) + do = torch.rand_like(v).cuda() + o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True) + o.backward(do) + dq, q.grad = q.grad.clone(), None + dk, k.grad = k.grad.clone(), None + dv, v.grad = v.grad.clone(), None + dw, w.grad = w.grad.clone(), None + du, u.grad = u.grad.clone(), None + dh0, h0.grad = h0.grad.clone(), None + o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True) + o2.backward(do) + torch.testing.assert_close(o, o2, rtol=0, atol=1e-4) + torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4) + torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4) + torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4) + torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4) + torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4) + torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4) + torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4) + + print("All tests passed!") + + @triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['T'], + # different possible values for `x_name` + x_vals=[128 * 2 ** i for i in range(0, 8)], + # argument name whose value corresponds to a different line in the plot + line_arg='provider', + # possible values for `line_arg`` + line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'], + # label name for the lines + line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'], + # line styles + styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')], + ylabel="Execution Time (ms)", # label name for the y-axis + # name for the plot. Used also as a file name for saving the plot. + plot_name="Performance", + args={}, + ) + ) + def benchmark(T, provider): + device = 'cuda' + dtype = torch.bfloat16 + requires_grad = True + B, H, K = 16, 4, 128 + + q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True) + u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype) + + do = torch.ones_like(q, dtype=dtype) + quantiles = [0.5, 0.2, 0.8] + results = 0, 0, 0 + if provider == 'recurrent': + results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles) + if provider == 'chunk': + results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles) + if provider == 'recurrent_bwd': + results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u) + [0].backward(do), quantiles=quantiles) + if provider == 'chunk_bwd': + results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles) + return results + benchmark.run(print_data=True) diff --git a/finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py b/finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py new file mode 100644 index 0000000..e4cf9bd --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv6/chunk_naive.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +from fla.ops.rwkv6.chunk import chunk_rwkv6 +from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6 + + +def naive_chunk_rwkv6( + q, + k, + v, + w, + u, + chunk_size=32, + initial_state=None, + output_final_state=True, +): + assert q.shape[-2] % chunk_size == 0 + orig_dtype = q.dtype + num_chunk = q.shape[-2] // chunk_size + u = u.unsqueeze(0) + + q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) + + w_cumsum = w.cumsum(-2) + + kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() + wkv = kw.transpose(-1, -2) @ v + + wkv_new = torch.zeros_like(wkv) + + for i in range(num_chunk - 1): + wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] + + o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) + + o_intra = torch.zeros_like(o_inter) + for i in range(chunk_size): + attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) + mask = (torch.arange(0, chunk_size) < i).to(attn.device) + attn.masked_fill_(~mask, 0) + intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) + intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] + o_intra[:, :, :, i] = intra_inter_o + intra_intra_o + o = o_inter + o_intra + return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 1024 + D = 100 + dtype = torch.bfloat16 + require_grad = True + q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad) + k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad) + v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad) + w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad) + u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad) + do = torch.rand_like(v).cuda() + o2, _ = chunk_rwkv6(q, k, v, w.clone(), u) + o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0) + o.backward(do) + dq, q.grad = q.grad.clone(), None + dk, k.grad = k.grad.clone(), None + dv, v.grad = v.grad.clone(), None + dw, w.grad = w.grad.clone(), None + du, u.grad = u.grad.clone(), None + print((o - o2).abs().max()) + o2.backward(do) + print((o-o2).abs().max()) + print((q.grad - dq).abs().max()) + print((k.grad - dk).abs().max()) + print((v.grad - dv).abs().max()) + print((w.grad - dw).abs().max()) + print((u.grad - du).abs().max()) diff --git a/finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py b/finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py new file mode 100644 index 0000000..af25152 --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv6/recurrent_fuse.py @@ -0,0 +1,378 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.ops.utils import chunk_reversed_cumsum_fwd +from fla.utils import contiguous + + +@triton.jit +def fused_recurrent_rwkv6_fwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, K] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + o, # output [B, H, T, V] + # initial hidden state initialization [B, H, K, V] + h0, + ht, # final hidden state [B, H, K, V] + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + scale, # K ** -0.5 + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bv[:, None] & mask_bk[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_w = tl.exp(b_w) + b_kv = b_k[None, :] * b_v[:, None] + b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + b_h = b_h * b_w[None, :] + b_h += b_kv + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + p_q += -K if REVERSE else K + p_k += -K if REVERSE else K + p_o += -V if REVERSE else V + p_v += -V if REVERSE else V + p_w += -K if REVERSE else K + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_rwkv6_bwd_kernel_dq( + # B: B, H: H, T: T, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + k, # key [B, H, T, V] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + + do, # gradient of output [B, H, T, V] + dq, # gradient of query [NV, B, H, T, K] + dq_aux, # gradient of query_aux [NV, B, H, T, K] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T: tl.constexpr, # T + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + K: tl.constexpr, # K + V: tl.constexpr, # V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + mask_kv = mask_bv[:, None] & mask_bk[None, :] + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_kv = b_k[None, :] * b_v[:, None] + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_w = tl.exp(b_w) + h_q = b_h * b_do[:, None] + b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0) + b_dq *= scale + b_dq_aux = tl.sum(h_q, axis=0) + b_h = b_h * b_w[None, :] + b_h += b_kv + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk) + tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk) + p_k += -K if REVERSE else K + p_do += -V if REVERSE else V + p_v += -V if REVERSE else V + p_w += -K if REVERSE else K + p_dq += -K if REVERSE else K + p_dq_aux += -K if REVERSE else K + + +@triton.jit +def fused_recurrent_rwkv6_bwd_kernel_dkv( + # B: B, H: H, T: T, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + + do, # gradient of output [B, H, T, V] + dk, + dk_aux, + dv, + dh0, + + # initial hidden state initialization [B, H, K, V] + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + + scale, # K ** -0.5 + B, # B + H, # H + T, # T + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + K: tl.constexpr, # K + V: tl.constexpr, # V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + mask_kv = mask_bk[:, None] & mask_bv[None, :] + + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + + for _ in range(T-1, -1, -1): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_dkv = b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk) + b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1) + b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + b_dh *= tl.exp(b_w)[:, None] + b_dh += b_dkv + + p_q += K if REVERSE else -K + p_k += K if REVERSE else -K + p_v += V if REVERSE else -V + p_w += K if REVERSE else -K + p_do += V if REVERSE else -V + p_dk += K if REVERSE else -K + p_dk_aux += K if REVERSE else -K + p_dv += V if REVERSE else -V + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv) + + +class FusedRecurrentRWKV6Function(torch.autograd.Function): + + @staticmethod + @contiguous + @custom_fwd + def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False): + # alias + q = r + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + grid = (NV, NK, B * H) + fused_recurrent_rwkv6_fwd_kernel[grid]( + q, k, v, w, u, o, initial_state, final_state, + k.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, w, u, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + # we do not need the gradient of the final state from the next chunk + # similiar to Trunctated BPTT + if final_state is not None: + final_state = final_state.detach() + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @custom_bwd + def backward(ctx, do, d_final_state=None): + q, k, v, w, u, initial_state, o = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + + BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dq_aux = torch.empty_like(dq) + grid = (NV, NK, B * H) + + fused_recurrent_rwkv6_bwd_kernel_dq[grid]( + k, v, w, u, do, dq, dq_aux, initial_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + ) + dq = dq.sum(0).to(q) + dq_aux = dq_aux.sum(0) + + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None + grid = (NV, NK, B * H) + fused_recurrent_rwkv6_bwd_kernel_dkv[grid]( + q, k, v, w, u, do, dk, dk_aux, dv, dh0, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + ) + dk = dk.sum(0).to(k) + dv = dv.sum(0).to(v) + dk_aux = dk_aux.sum(0) + + dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1] + dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0) + dw = chunk_reversed_cumsum_fwd(dw).to(w) + + du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u) + return dq, dk, dv, dw, du, None, dh0, None, None + + +def fused_recurrent_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + causal: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + reception of shape `(B, H, T, K)`. Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g. + u (torch.Tensor): + bonus of shape `(H, K)` + scale (Optional[int]): + Scale factor for the RWKV6 attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + """ + if scale == -1: + scale = r.shape[-1] ** -0.5 + o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py b/finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py new file mode 100644 index 0000000..7b1b67e --- /dev/null +++ b/finetune/lora/v6/fla/ops/rwkv6/recurrent_naive.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +): + orig_dtype = q.dtype + B, H, T, K, V = *q.shape, v.shape[-1] + q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + + if scale is None: + scale = K ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(T): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] + o[:, :, i] = o_i.sum(-2) + h = h * w_i[..., None] + kv_i + ht = h if output_final_state else None + return o.to(orig_dtype), ht + + +def naive_recurrent_rwkv6_bwd( + q, + k, + v, + w, + u, + o, + do, + initial_state=None, + output_final_state=False +): + q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do)) + B, H, T, K, V = *q.shape, v.shape[-1] + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + dq = torch.zeros_like(q) + dq_aux = torch.zeros_like(q) + + if initial_state is not None: + h += initial_state + + for i in range(T): + k_i = k[:, :, i] + v_i = v[:, :, i] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h_i = (h + u[None, ..., None] * kv_i) + dq_i = (do[:, :, i, None, :] * h_i).sum(-1) + dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) + dq[:, :, i] = dq_i + dq_aux[:, :, i] = dq_aux_i + h = h * w_i[..., None] + kv_i + + du = torch.zeros_like(u) + dh = torch.zeros_like(h) + dk = torch.zeros_like(k) + dk_aux = torch.zeros_like(k) + dv = torch.zeros_like(v) + + for i in range(T - 1, -1, -1): + d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] + k_i = k[:, :, i] + v_i = v[:, :, i] + du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) + du += du_i + dk_i = (dh * v_i[..., None, :]).sum(-1) + dk_aux[:, :, i] = dk_i + dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) + dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) + dv_i += (dh * k_i[..., None]).sum(-2) + + dk[:, :, i] = dk_i + dv[:, :, i] = dv_i + dh = dh * w[:, :, i, :, None].exp() + d_kv_i + + # dw = q * dq_aux - k * dk_aux + dw = torch.zeros_like(w) + for i in range(T - 2, -1, -1): + dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] + + return dq, dk, dv, dw, du diff --git a/finetune/lora/v6/fla/ops/simple_gla/README.md b/finetune/lora/v6/fla/ops/simple_gla/README.md new file mode 100644 index 0000000..72e710a --- /dev/null +++ b/finetune/lora/v6/fla/ops/simple_gla/README.md @@ -0,0 +1,5 @@ +- Simple GLA + +Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. + +$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. \ No newline at end of file diff --git a/finetune/lora/v6/fla/ops/simple_gla/__init__.py b/finetune/lora/v6/fla/ops/simple_gla/__init__.py new file mode 100644 index 0000000..b2f9066 --- /dev/null +++ b/finetune/lora/v6/fla/ops/simple_gla/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_simple_gla + +__all__ = [ + 'chunk_simple_gla' +] + diff --git a/finetune/lora/v6/fla/ops/simple_gla/chunk.py b/finetune/lora/v6/fla/ops/simple_gla/chunk.py new file mode 100644 index 0000000..a5ca3d4 --- /dev/null +++ b/finetune/lora/v6/fla/ops/simple_gla/chunk.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from torch.cuda.amp import custom_bwd, custom_fwd + +from fla.utils import contiguous + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.transpose(-2, -1) + k = k.cumsum(-1) + k = k.transpose(-2, -1) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-5) + + +@triton.jit +def chunk_simple_gla_fwd_kernel_h( + k, + v, + h, + g, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, + (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, + (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1) + b_h *= tl.math.exp2(b_g_last) + b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT)) + b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr( + final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_simple_gla_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr( + k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, + (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT] + + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT) + b_g = tl.load(p_g) + b_o = b_o * tl.math.exp2(b_g)[:, None] + b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :]) + b_s = tl.where(m_s, b_s, 0) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), + (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), + (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_simple_gla_bwd_kernel_dh( + q, + g, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, + (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T + + i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1)) + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + + +@triton.jit +def chunk_simple_gla_bwd_kernel_dqkv( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), + (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT) + b_g = tl.load(p_g) + b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1) + mask = tl.math.exp2(b_g[None, :] - b_g[:, None]) + mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0) + b_s = b_s * mask + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), + (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), + (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), + (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \ + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = b_dq * tl.math.exp2(b_g)[:, None] + b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None] + b_ds = b_ds * tl.trans(mask) + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class SimpleGLAFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd + @contiguous + def forward(ctx, q, k, v, g, initial_state, output_final_state): + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min( + 64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = K ** -0.5 + + BT = 64 + assert T % BT == 0, 'sequence length must be divisible by BT' + g = g.reshape(B, H, -1, BT) + g = g.cumsum(-1) * 1.44269504 + g = g.reshape(B, H, -1) + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False) + + h = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_simple_gla_fwd_kernel_h[grid]( + k, v, h, g, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NV, NT, B * H) + o = torch.empty_like(v) + chunk_simple_gla_fwd_kernel_o[grid]( + q, k, v, h, g, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + ctx.save_for_backward(q, k, v, h, g) + return o.to(q.dtype), final_state + + @staticmethod + @custom_bwd + @contiguous + def backward(ctx, do, d_ht=None): + q, k, v, h, g = ctx.saved_tensors + + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min( + 32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = K ** -0.5 + + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_simple_gla_bwd_kernel_dh[grid]( + q, g, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + chunk_simple_gla_bwd_kernel_dqkv[grid]( + q, k, v, h, g, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dg = (dq * q - dk * k).sum(-1) + + def rev_cumsum(x): + cumsum_x = x.cumsum(-1) + rev_cumsum_x = cumsum_x[..., -1, None] - cumsum_x + return rev_cumsum_x + x + dg = rev_cumsum(dg) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, None + + +def chunk_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, # log decay + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if initial_state is not None: + initial_state = initial_state.detach() + g = g.float() + o, final_state = SimpleGLAFunction.apply(q, k, v, g, initial_state, output_final_state) + return o, final_state diff --git a/finetune/lora/v6/fla/ops/simple_gla/naive.py b/finetune/lora/v6/fla/ops/simple_gla/naive.py new file mode 100644 index 0000000..f7f1e22 --- /dev/null +++ b/finetune/lora/v6/fla/ops/simple_gla/naive.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def torch_simple_gla(q, k, v, g, chunk_size=64): + q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) + k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) + g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) + g = g.cumsum(-1) + kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) + S = torch.zeros_like(kv) + + for i in range(1, g.shape[-2]): + S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] + + inter = (q * g[..., None].exp()) @ S + attn = q @ k.transpose(-1, -2) + attn = attn * (g[..., None] - g[..., None, :]).exp() + attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + intra = attn @ v + o = inter + intra + return rearrange(o, 'b h n c d -> b h (n c) d') + + +def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64): + # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) + # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) + # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) + # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) + # g = g.cumsum(-1) + # kv = k.transpose(-1, -2) @ v + + B, H, T, DK = q.shape + q = q * (DK ** -0.5) + _, _, _, DV = v.shape + S = torch.zeros(B, H, DK, DV).to(q) + o = torch.zeros(B, H, T, DV).to(q) + for i in range(T): + gate = g[:, :, i].exp() + key = k[:, :, i] + value = v[:, :, i] + kv = key.unsqueeze(-1) * value.unsqueeze(-2) + S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv + q_i = q[:, :, i, :] + o_i = (q_i.unsqueeze(-1) * S).sum(-2) + o[:, :, i] = o_i + + return o + diff --git a/finetune/lora/v6/fla/ops/utils.py b/finetune/lora/v6/fla/ops/utils.py new file mode 100644 index 0000000..5020209 --- /dev/null +++ b/finetune/lora/v6/fla/ops/utils.py @@ -0,0 +1,579 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import contiguous + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def logcumsumexp_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_bh = tl.program_id(0) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + + b_mp = tl.full([S,], float('-inf'), dtype=tl.float32) + b_zp = tl.zeros([S,], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + + # [BT, S] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + # [S,] + b_mc = tl.max(b_s, 0) + # workaround for compiler bugs + if i_t > 0: + b_mc = tl.maximum(b_mp, b_mc) + b_zp = b_zp * tl.exp(b_mp - b_mc) + # [BT, S] + b_s = tl.exp(b_s - b_mc) + b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp + # [S,] + b_zc = tl.max(b_z, 0) + b_mp = b_mc + b_zp = b_zc + # [BT, BS] + # small eps to prevent underflows + b_z = tl.log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc + tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def softmax_fwd_kernel( + s, + p, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + + # [BT, S] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + # [BT] + b_m = tl.max(b_s, 1) + + # [BT, BS] + b_s = tl.exp(b_s - b_m[:, None]) + b_z = tl.sum(b_s, 1) + b_p = tl.where(b_s != 0, b_s / b_z[:, None], 0.) + tl.store(p_p, b_p.to(p_p.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def softmax_bwd_kernel( + p, + dp, + ds, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_p = tl.make_block_ptr(p + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_dp = tl.make_block_ptr(dp + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, 0), (BT, S), (1, 0)) + # [BT, BS] + b_p = tl.load(p_p, boundary_check=(0, 1)).to(tl.float32) + b_dp = tl.load(p_dp, boundary_check=(0, 1)).to(tl.float32) + # [BT,] + b_pp = tl.sum(b_p * b_dp, 1) + # [BT, BS] + b_ds = b_p * b_dp - b_p * b_pp[:, None] + tl.store(p_ds, b_ds.to(p_ds.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + triton.Config({'BS': 128}, num_warps=2), + triton.Config({'BS': 128}, num_warps=4), + triton.Config({'BS': 128}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def recurrent_cumsum_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + T: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + + o_s = i_s * BS + tl.arange(0, BS) + mask = o_s < S + + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(0, T): + # [BS] + b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32) + b_z = b_z + b_s + + tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + triton.Config({'BS': 128}, num_warps=2), + triton.Config({'BS': 128}, num_warps=4), + triton.Config({'BS': 128}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def recurrent_cumsum_bwd_kernel( + ds, + dz, + s_s_h, + s_s_t, + T: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + + o_s = i_s * BS + tl.arange(0, BS) + mask = o_s < S + + b_ds = tl.zeros([BS], dtype=tl.float32) + for i_t in range(T - 1, -1, -1): + # [BS] + b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32) + b_ds = b_ds + b_dz + + tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_cumsum_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + + if i_t >= 0: + b_z += tl.sum(b_s, 0) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_cumsum_bwd_kernel( + ds, + dz, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + + b_ds = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32) + b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False) + tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1)) + + if i_t >= 0: + b_ds += tl.sum(b_dz, 0) + + +@contiguous +def chunk_cumsum_fwd( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + B, H, T, S = s.shape + BS = 32 + + dtype = dtype or s.dtype + grid = (triton.cdiv(S, BS), B * H) + z = torch.empty_like(s, dtype=dtype) + chunk_cumsum_fwd_kernel[grid]( + s, z, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BS=BS + ) + return z + + +@contiguous +def chunk_cumsum_bwd( + dz: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + B, H, T, S = dz.shape + BS = 32 + + dtype = dtype or dz.dtype + grid = (triton.cdiv(S, BS), B * H) + ds = torch.empty_like(dz, dtype=dtype) + chunk_cumsum_bwd_kernel[grid]( + ds, dz, + ds.stride(1), ds.stride(2), ds.stride(3), + T=T, S=S, BS=BS + ) + return ds + + +class CumsumFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, s, dtype): + z = chunk_cumsum_fwd(s, dtype) + ctx.dtype = dtype + return z + + @staticmethod + def backward(ctx, dz): + ds = chunk_cumsum_bwd(dz, ctx.dtype) + return ds, None + + +def cumsum( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return CumsumFunction.apply(s, dtype) + + +@triton.autotune( + configs=[ + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + triton.Config({'BS': 128}, num_warps=2), + triton.Config({'BS': 128}, num_warps=4), + triton.Config({'BS': 128}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def recurrent_reversed_cumsum_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + T: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + + o_s = i_s * BS + tl.arange(0, BS) + mask = o_s < S + + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(T - 1, -1, -1): + # [BS] + b_s = tl.load(s + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32) + b_z = b_z + b_s + + tl.store(z + i_bh * s_s_h + i_t * s_s_t + o_s, b_z.to(s.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BS': 32}, num_warps=2), + triton.Config({'BS': 32}, num_warps=4), + triton.Config({'BS': 32}, num_warps=8), + triton.Config({'BS': 64}, num_warps=2), + triton.Config({'BS': 64}, num_warps=4), + triton.Config({'BS': 64}, num_warps=8), + triton.Config({'BS': 128}, num_warps=2), + triton.Config({'BS': 128}, num_warps=4), + triton.Config({'BS': 128}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def recurrent_reversed_cumsum_bwd_kernel( + ds, + dz, + s_s_h, + s_s_t, + T: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + + o_s = i_s * BS + tl.arange(0, BS) + mask = o_s < S + + b_ds = tl.zeros([BS], dtype=tl.float32) + for i_t in range(0, T): + # [BS] + b_dz = tl.load(dz + i_bh * s_s_h + i_t * s_s_t + o_s, mask=mask, other=0).to(tl.float32) + b_ds = b_ds + b_dz + + tl.store(ds + i_bh * s_s_h + i_t * s_s_t + o_s, b_ds.to(ds.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_reversed_cumsum_fwd_kernel( + s, + z, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + + b_z = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + + if i_t >= 0: + b_z += tl.sum(b_s, 0) + + +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 16}, num_warps=4), + triton.Config({'BT': 16}, num_warps=8), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=8), + triton.Config({'BT': 64}, num_warps=2), + triton.Config({'BT': 64}, num_warps=4), + triton.Config({'BT': 64}, num_warps=8), + ], + key=['S'] +) +@triton.jit +def chunk_reversed_cumsum_bwd_kernel( + ds, + dz, + s_s_h, + s_s_t, + s_s_d, + T: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + + b_ds = tl.zeros([BS], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_ds = tl.make_block_ptr(ds + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_dz = tl.make_block_ptr(dz + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_dz = tl.load(p_dz, boundary_check=(0, 1)).to(tl.float32) + b_c = b_ds[None, :] + tl.dot(m_s, b_dz, allow_tf32=False) + tl.store(p_ds, b_c.to(p_ds.dtype.element_ty), boundary_check=(0, 1)) + + if i_t >= 0: + b_ds += tl.sum(b_dz, 0) + + +@contiguous +def chunk_reversed_cumsum_fwd( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + B, H, T, S = s.shape + BS = 32 + + dtype = dtype or s.dtype + grid = (triton.cdiv(S, BS), B * H) + z = torch.empty_like(s, dtype=dtype) + chunk_reversed_cumsum_fwd_kernel[grid]( + s, z, + s.stride(1), s.stride(2), s.stride(3), + T=T, S=S, BS=BS + ) + return z + + +@contiguous +def chunk_reversed_cumsum_bwd( + dz: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + B, H, T, S = dz.shape + BS = 32 + + dtype = dtype or dz.dtype + grid = (triton.cdiv(S, BS), B * H) + ds = torch.empty_like(dz, dtype=dtype) + chunk_reversed_cumsum_bwd_kernel[grid]( + ds, dz, + ds.stride(1), ds.stride(2), ds.stride(3), + T=T, S=S, BS=BS + ) + return ds + + +class ReversedCumsumFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, s, dtype): + z = chunk_reversed_cumsum_fwd(s, dtype) + ctx.dtype = dtype + return z + + @staticmethod + def backward(ctx, dz): + ds = chunk_reversed_cumsum_bwd(dz, ctx.dtype) + return ds, None + + +def reversed_cumsum( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + return CumsumFunction.apply(s, dtype) diff --git a/finetune/lora/v6/fla/utils.py b/finetune/lora/v6/fla/utils.py new file mode 100644 index 0000000..7a9a047 --- /dev/null +++ b/finetune/lora/v6/fla/utils.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +import functools + +import torch + + +def contiguous(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + + +def require_version(version, hint): + def decorator(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + from transformers.utils.versions import require_version + require_version(version, hint) + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + return decorator + + +def checkpoint(func): + def wrapper(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(func, *args, **kwargs) + return wrapper diff --git a/finetune/lora/v6/merge/merge.py b/finetune/lora/v6/merge/merge.py new file mode 100644 index 0000000..3f97cdf --- /dev/null +++ b/finetune/lora/v6/merge/merge.py @@ -0,0 +1,92 @@ +from collections import OrderedDict +import os +import sys +from typing import Dict +import typing +import torch +import bitsandbytes as bnb +from argparse import ArgumentParser + +parser = ArgumentParser() +parser.add_argument("--type", default="pissa", type=str) +parser.add_argument("--base_model", default="", type=str) +parser.add_argument("--lora_init", default="none", type=str) +parser.add_argument("--lora_checkpoint", default="", type=str) +parser.add_argument("--output", default="", type=str) +parser.add_argument("--quant", default="none", type=str) +parser.add_argument("--device", default="cuda", type=str) +parser.add_argument("--lora_alpha", default=16, type=int) +args = parser.parse_args() +device= args.device +base_model = args.base_model +init_lora= args.lora_init +lora= args.lora_checkpoint +output= args.output +quant= args.quant +lora_alpha = args.lora_alpha + +with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') + # merge LoRA-only slim checkpoint into the main weights + w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') + + if args.type=='pissa': + w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu') + for k in w_lora.keys(): + w[k] = w_lora[k] + output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() + # merge LoRA weights + keys = list(w.keys()) + for k in keys: + if k.endswith('.weight'): + prefix = k[:-len('.weight')] + lora_A = prefix + '.lora_A' + lora_B = prefix + '.lora_B' + init_lora_A = prefix + '.init_lora_A' + init_lora_B = prefix + '.init_lora_B' + if lora_A in keys: + assert lora_B in keys + print(f'merging {lora_A} and {lora_B} into {k}') + assert w[lora_B].shape[1] == w[lora_A].shape[0] + lora_r = w[lora_B].shape[1] + w[k] = w[k].to(device=device) + w[lora_A] = w[lora_A].to(device=device) + w[lora_B] = w[lora_B].to(device=device) + + if args.type=='pissa': + w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device) + w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device) + if quant=='4bit': + qw,qs = bnb.functional.quantize_4bit(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) + w[k] = (bnb.functional.dequantize_4bit(qw,quant_state=qs)).to(dtype=torch.bfloat16) + elif quant == 'nf4': + qw,qs = bnb.functional.quantize_nf4(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) + w[k] = (bnb.functional.dequantize_nf4(qw,quant_state=qs)).to(dtype=torch.bfloat16) + elif quant == 'fp4': + qw,qs = bnb.functional.quantize_fp4(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) + w[k] = (bnb.functional.dequantize_fp4(qw,quant_state=qs)).to(dtype=torch.bfloat16) + else: + w[k] = (w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16) + w[k] += w[lora_B] @ w[lora_A] + else: + if quant=='4bit': + qw,qs = bnb.functional.quantize_4bit(w[k]) + w[k] = (bnb.functional.dequantize_4bit(qw,quant_state=qs)).to(dtype=torch.bfloat16) + elif quant=='nf4': + qw,qs = bnb.functional.quantize_nf4(w[k]) + w[k] = (bnb.functional.dequantize_nf4(qw,quant_state=qs)).to(dtype=torch.bfloat16) + elif quant=='fp4': + qw,qs = bnb.functional.quantize_fp4(w[k]) + w[k] = (bnb.functional.dequantize_fp4(qw,quant_state=qs)).to(dtype=torch.bfloat16) + w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) + output_w[k] = w[k].to(device='cpu', copy=True) + del w[k] + del w[lora_A] + del w[lora_B] + continue + + if 'lora' not in k: + print(f'retaining {k}') + output_w[k] = w[k].clone() + del w[k] + torch.save(output_w, output) \ No newline at end of file diff --git a/finetune/lora/v6/merge/merge_lora.py b/finetune/lora/v6/merge/merge_lora.py new file mode 100644 index 0000000..ee38fdf --- /dev/null +++ b/finetune/lora/v6/merge/merge_lora.py @@ -0,0 +1,52 @@ +from collections import OrderedDict +import os +import sys +from typing import Dict +import typing +import torch + +if '-h' in sys.argv or '--help' in sys.argv: + print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') + +if sys.argv[1] == '--use-gpu': + device = 'cuda' + lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] +else: + device = 'cpu' + lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] + + +with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') + # merge LoRA-only slim checkpoint into the main weights + w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') + for k in w_lora.keys(): + w[k] = w_lora[k] + output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() + # merge LoRA weights + keys = list(w.keys()) + for k in keys: + if k.endswith('.weight'): + prefix = k[:-len('.weight')] + lora_A = prefix + '.lora_A' + lora_B = prefix + '.lora_B' + if lora_A in keys: + assert lora_B in keys + print(f'merging {lora_A} and {lora_B} into {k}') + assert w[lora_B].shape[1] == w[lora_A].shape[0] + lora_r = w[lora_B].shape[1] + w[k] = w[k].to(device=device) + w[lora_A] = w[lora_A].to(device=device) + w[lora_B] = w[lora_B].to(device=device) + w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) + output_w[k] = w[k].to(device='cpu', copy=True) + del w[k] + del w[lora_A] + del w[lora_B] + continue + + if 'lora' not in k: + print(f'retaining {k}') + output_w[k] = w[k].clone() + del w[k] + torch.save(output_w, output) diff --git a/finetune/lora/v6/merge/merge_pissa.py b/finetune/lora/v6/merge/merge_pissa.py new file mode 100644 index 0000000..001358d --- /dev/null +++ b/finetune/lora/v6/merge/merge_pissa.py @@ -0,0 +1,58 @@ +from collections import OrderedDict +import os +import sys +from typing import Dict +import typing +import torch + +if '-h' in sys.argv or '--help' in sys.argv: + print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') + +if sys.argv[1] == '--use-gpu': + device = 'cuda' + base_model, init_lora, lora, output = sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5] +else: + device = 'cpu' + base_model, init_lora, lora, output = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4] + + +with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') + # merge LoRA-only slim checkpoint into the main weights + w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') + w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu') + for k in w_lora.keys(): + w[k] = w_lora[k] + output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() + # merge LoRA weights + keys = list(w.keys()) + for k in keys: + if k.endswith('.weight'): + prefix = k[:-len('.weight')] + lora_A = prefix + '.lora_A' + lora_B = prefix + '.lora_B' + init_lora_A = prefix + '.init_lora_A' + init_lora_B = prefix + '.init_lora_B' + if lora_A in keys: + assert lora_B in keys + print(f'merging {lora_A} and {lora_B} into {k}') + assert w[lora_B].shape[1] == w[lora_A].shape[0] + lora_r = w[lora_B].shape[1] + w[k] = w[k].to(device=device) + w[lora_A] = w[lora_A].to(device=device) + w[lora_B] = w[lora_B].to(device=device) + w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device) + w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device) + w[k] = (w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16) + w[k] += w[lora_B] @ w[lora_A] + output_w[k] = w[k].to(device='cpu', copy=True) + del w[k] + del w[lora_A] + del w[lora_B] + continue + + if 'lora' not in k: + print(f'retaining {k}') + output_w[k] = w[k].clone() + del w[k] + torch.save(output_w, output) \ No newline at end of file diff --git a/finetune/lora/v6/merge/merge_state.py b/finetune/lora/v6/merge/merge_state.py new file mode 100644 index 0000000..51e4767 --- /dev/null +++ b/finetune/lora/v6/merge/merge_state.py @@ -0,0 +1,36 @@ +from collections import OrderedDict +import os +import sys +from typing import Dict +import typing +import torch +import bitsandbytes as bnb +from argparse import ArgumentParser + +parser = ArgumentParser() +parser.add_argument("--base_model", default="", type=str) +parser.add_argument("--state_checkpoint", default="", type=str) +parser.add_argument("--output", default="", type=str) +# parser.add_argument("--quant", default="none", type=str) +parser.add_argument("--device", default="cuda", type=str) +# parser.add_argument("--lora_alpha", default=16, type=int) +args = parser.parse_args() +device= args.device +base_model = args.base_model +state= args.state_checkpoint +output= args.output + + +with torch.no_grad(): + w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') + # merge LoRA-only slim checkpoint into the main weights + w_state: Dict[str, torch.Tensor] = torch.load(state, map_location='cpu') + + for k in w_state.keys(): + print(k) + w[k] = w_state[k] + # merge LoRA weights + for k in w.keys(): + print(k) + + torch.save(w, output) \ No newline at end of file diff --git a/finetune/lora/v6/requirements.txt b/finetune/lora/v6/requirements.txt new file mode 100644 index 0000000..b5ca526 --- /dev/null +++ b/finetune/lora/v6/requirements.txt @@ -0,0 +1,5 @@ +pytorch-lightning==1.9.5 +bitsandbytes +deepspeed +einops +triton==2.2.0 \ No newline at end of file diff --git a/finetune/lora/v6/src/binidx.py b/finetune/lora/v6/src/binidx.py index c2d60a1..8631f1e 100644 --- a/finetune/lora/v6/src/binidx.py +++ b/finetune/lora/v6/src/binidx.py @@ -270,8 +270,10 @@ class MMapIndexedDataset(torch.utils.data.Dataset): np_array = np.append(np_array, np_array0) return np_array - def only(self, idx): + def only(self, idx, length=None): ptr, size = self._index[idx] + if length < size: + size = length np_array = np.frombuffer( self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr ) diff --git a/finetune/lora/v6/src/dataset.py b/finetune/lora/v6/src/dataset.py index 949e28a..69c6b4d 100644 --- a/finetune/lora/v6/src/dataset.py +++ b/finetune/lora/v6/src/dataset.py @@ -179,8 +179,12 @@ class MyDataset(Dataset): if args.data_type == "binidx": if args.my_pile_version == 1: - dix = data.get(idx=0, offset=i, length=req_len).astype(int) - # dix = data.pad(idx=idx, length=req_len).astype(int) + if args.dataload == "pad": + dix = data.pad(idx=idx, length=req_len).astype(int) + elif args.dataload == "only": + dix = data.only(idx=idx, length=req_len).astype(int) + else: + dix = data.get(idx=0, offset=i, length=req_len).astype(int) else: # self.data : cutoff, chunk_count, data for j in range(len(data)): diff --git a/finetune/lora/v6/src/infctx_module.py b/finetune/lora/v6/src/infctx_module.py new file mode 100644 index 0000000..e24a74f --- /dev/null +++ b/finetune/lora/v6/src/infctx_module.py @@ -0,0 +1,52 @@ +import torch +######state +class TimeMixState: + def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor): + self.shift_state = shift_state + self.wkv_state = wkv_state + + +class ChannelMixState: + def __init__(self, shift_state: torch.Tensor): + self.shift_state = shift_state + + +class BlockState: + def __init__(self, time_mix_state: TimeMixState, + channel_mix_state: ChannelMixState): + self.time_mix_state = time_mix_state + self.channel_mix_state = channel_mix_state + +class BlockStateList: + + def __init__(self, shift_states, wkv_states): + self.wkv_states = wkv_states + self.shift_states = shift_states + + @staticmethod + def create(N, B, C, H, device, dtype): + result = BlockStateList.empty(N, B, C, H, device, dtype) + result.wkv_states[:] = 0 + result.wkv_states[:] = 0 + result.shift_states[:] = 0 + return result + + @staticmethod + def empty(N, B, C, H, device, dtype): + wkv_states = torch.empty((N, B, H, C//H, C//H), + device=device, + dtype=torch.bfloat16) + shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype) + return BlockStateList(shift_states, wkv_states) + + def __getitem__(self, layer: int): + return BlockState( + TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]), + ChannelMixState(self.shift_states[layer, 1])) + + def __setitem__(self, layer: int, state: BlockState): + self.shift_states[layer, 0] = state.time_mix_state.shift_state + self.wkv_states[layer] = state.time_mix_state.wkv_state + self.shift_states[layer, 1] = state.channel_mix_state.shift_state + + diff --git a/finetune/lora/v6/src/model.py b/finetune/lora/v6/src/model.py index dbeb0f8..6830bf4 100644 --- a/finetune/lora/v6/src/model.py +++ b/finetune/lora/v6/src/model.py @@ -18,7 +18,11 @@ from pytorch_lightning.strategies import DeepSpeedStrategy if importlib.util.find_spec("deepspeed"): import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam - +from torch._lowrank import svd_lowrank +import bitsandbytes as bnb +from .infctx_module import * +from einops import rearrange +from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 LORA_CONFIG = { "r": 0, @@ -29,6 +33,7 @@ LORA_CONFIG = { class LoraLinear(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool): super().__init__() @@ -44,12 +49,96 @@ class LoraLinear(nn.Module): self.lora_B = nn.Parameter(torch.empty(out_features, r)) self.lora_dropout = nn.Dropout(dropout) self.scaling = alpha / r - + self.r = r nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.zeros_(self.lora_B) + self.pissa = False + self.is_quant = False + + def pissa_init(self, svd_niter): + + self.pissa = True + Ur, Sr, Vr = svd_lowrank(self.weight.data, self.r, niter=svd_niter) + Vhr = Vr.t() + lora_A = torch.diag(torch.sqrt(Sr)) @ Vhr + lora_B = Ur @ torch.diag(torch.sqrt(Sr)) + self.lora_A.data = lora_A + self.lora_B.data = lora_B + self.weight.data = self.weight.data - lora_B @ lora_A + + def quant(self, quant_type): + self.is_quant = True + self.quant_type = quant_type + if self.quant_type == "4bit": + self.weight.data, self.qstate = bnb.functional.quantize_4bit( + (self.weight.data).to("cuda") + ) + elif self.quant_type == "nf4": + self.weight.data, self.qstate = bnb.functional.quantize_nf4( + (self.weight.data).to("cuda") + ) + elif self.quant_type == "fp4": + self.weight.data, self.qstate = bnb.functional.quantize_fp4( + (self.weight.data).to("cuda") + ) def forward(self, x): + + if self.is_quant: + if self.quant_type == "4bit": + if self.pissa: + return F.linear( + x, + bnb.functional.dequantize_4bit( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + F.linear(F.linear(x, self.lora_A), self.lora_B) + return F.linear( + x, + bnb.functional.dequantize_4bit( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + self.scaling * F.linear( + F.linear(self.lora_dropout(x), self.lora_A), self.lora_B + ) + elif self.quant_type == "nf4": + if self.pissa: + return F.linear( + x, + bnb.functional.dequantize_nf4( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + F.linear(F.linear(x, self.lora_A), self.lora_B) + return F.linear( + x, + bnb.functional.dequantize_nf4( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + self.scaling * F.linear( + F.linear(self.lora_dropout(x), self.lora_A), self.lora_B + ) + elif self.quant_type == "fp4": + if self.pissa: + return F.linear( + x, + bnb.functional.dequantize_fp4( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + F.linear(F.linear(x, self.lora_A), self.lora_B) + return F.linear( + x, + bnb.functional.dequantize_fp4( + self.weight.data, quant_state=self.qstate + ).to(torch.bfloat16), + ) + self.scaling * F.linear( + F.linear(self.lora_dropout(x), self.lora_A), self.lora_B + ) + + if self.pissa: + return F.linear(x, self.weight) + F.linear( + F.linear(x, self.lora_A), self.lora_B + ) return F.linear(x, self.weight) + self.scaling * F.linear( F.linear(self.lora_dropout(x), self.lora_A), self.lora_B ) @@ -91,215 +180,487 @@ if os.environ["RWKV_JIT_ON"] == "1": ######################################################################################################## # CUDA Kernel ######################################################################################################## +if os.environ["WKV"] == "fla": + if "x060" in os.environ["RWKV_MY_TESTING"]: + if ( + os.environ["RWKV_TRAIN_TYPE"] == "infctx" + and "x060" in os.environ["RWKV_MY_TESTING"] + ): -from torch.utils.cpp_extension import load + def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s): + r = rearrange(r, "b l (h d) -> b h l d", h=H) + k = rearrange(k, "b l (h d) -> b h l d", h=H) + v = rearrange(v, "b l (h d) -> b h l d", h=H) + w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H) + o, state = chunk_rwkv6( + r, k, v, w, u=u, initial_state=s, output_final_state=True + ) + x = rearrange(o, "b h l d -> b l (h d)") + return x, state -HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"]) + elif os.environ["RWKV_TRAIN_TYPE"] == "states": -if "x060" in os.environ["RWKV_MY_TESTING"]: - wkv6_cuda = load( - name="wkv6", - sources=[ - "finetune/lora/v6/cuda/wkv6_op.cpp", - f"finetune/lora/v6/cuda/wkv6_cuda.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={HEAD_SIZE}", - f"-D_T_={int(os.environ['RWKV_CTXLEN'])}", - ], - ) + def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s): + r = rearrange(r, "b l (h d) -> b h l d", h=H) + k = rearrange(k, "b l (h d) -> b h l d", h=H) + v = rearrange(v, "b l (h d) -> b h l d", h=H) + w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H) + o, _ = chunk_rwkv6( + r, k, v, w, u=u, initial_state=s, output_final_state=False + ) + x = rearrange(o, "b h l d -> b l (h d)") + return x - class WKV_6(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, H, r, k, v, w, u): - with torch.no_grad(): - assert r.dtype == torch.bfloat16 - assert k.dtype == torch.bfloat16 - assert v.dtype == torch.bfloat16 - assert w.dtype == torch.bfloat16 - assert u.dtype == torch.bfloat16 - assert HEAD_SIZE == C // H - ctx.B = B - ctx.T = T - ctx.C = C - ctx.H = H - assert r.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - assert w.is_contiguous() - assert u.is_contiguous() - ew = (-torch.exp(w.float())).contiguous() - ctx.save_for_backward(r, k, v, ew, u) - y = torch.empty( - (B, T, C), - device=r.device, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y) - return y + else: - @staticmethod - def backward(ctx, gy): - with torch.no_grad(): - assert gy.dtype == torch.bfloat16 - B = ctx.B - T = ctx.T - C = ctx.C - H = ctx.H - assert gy.is_contiguous() - r, k, v, ew, u = ctx.saved_tensors - gr = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - gk = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - gv = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - gw = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - gu = torch.empty( - (B, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-100, 100) - wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu) - gu = torch.sum(gu, 0).view(H, C // H) - return (None, None, None, None, gr, gk, gv, gw, gu) - - def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u): - return WKV_6.apply(B, T, C, H, r, k, v, w, u) + def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u): + r = rearrange(r, "b l (h d) -> b h l d", h=H) + k = rearrange(k, "b l (h d) -> b h l d", h=H) + v = rearrange(v, "b l (h d) -> b h l d", h=H) + w = rearrange(-torch.exp(w), "b l (h d) -> b h l d", h=H) + o, _ = chunk_rwkv6( + r, k, v, w, u=u, initial_state=None, output_final_state=False + ) + x = rearrange(o, "b h l d -> b l (h d)") + return x else: - wkv5_cuda = load( - name="wkv5", - sources=[ - "finetune/lora/v6/cuda/wkv5_op.cpp", - f"finetune/lora/v6/cuda/wkv5_cuda.cu", - ], - verbose=True, - extra_cuda_cflags=[ - "-res-usage", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-D_N_={HEAD_SIZE}", - ], - ) + from torch.utils.cpp_extension import load - class WKV_5(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, H, r, k, v, w, u): - with torch.no_grad(): - assert r.dtype == torch.bfloat16 - assert k.dtype == torch.bfloat16 - assert v.dtype == torch.bfloat16 - assert w.dtype == torch.bfloat16 - assert u.dtype == torch.bfloat16 - assert HEAD_SIZE == C // H - ctx.B = B - ctx.T = T - ctx.C = C - ctx.H = H - assert r.is_contiguous() - assert k.is_contiguous() - assert v.is_contiguous() - assert w.is_contiguous() - assert u.is_contiguous() - ew = (-torch.exp(w.float())).contiguous() - eew = (torch.exp(ew)).contiguous() - ctx.save_for_backward(r, k, v, eew, ew, u) - y = torch.empty( - (B, T, C), - device=r.device, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y) - return y + HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"]) - @staticmethod - def backward(ctx, gy): - with torch.no_grad(): - assert gy.dtype == torch.bfloat16 - B = ctx.B - T = ctx.T - C = ctx.C - H = ctx.H - assert gy.is_contiguous() - r, k, v, eew, ew, u = ctx.saved_tensors - gr = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - gk = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - gv = torch.empty( - (B, T, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - gw = torch.empty( - (B, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - gu = torch.empty( - (B, C), - device=gy.device, - requires_grad=False, - dtype=torch.bfloat16, - memory_format=torch.contiguous_format, - ) # .uniform_(-1, 1) - wkv5_cuda.backward( - B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu - ) - gw = torch.sum(gw, 0).view(H, C // H) - gu = torch.sum(gu, 0).view(H, C // H) - return (None, None, None, None, gr, gk, gv, gw, gu) + if "x060" in os.environ["RWKV_MY_TESTING"]: + if os.environ["RWKV_TRAIN_TYPE"] == "infctx": + wkv6state_cuda = load( + name="wkv6infctx", + sources=[ + "finetune/lora/v6/cuda/wkv6infctx_op.cpp", + f"finetune/lora/v6/cuda/wkv6infctx_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={int(os.environ['RWKV_CTXLEN'])}", + ], + ) - def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u): - return WKV_5.apply(B, T, C, H, r, k, v, w, u) + class WKV_6STATE(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u, s): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert s.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + assert s.is_contiguous() + ctx.save_for_backward(r, k, v, w, u, s) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, w, u, s = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gw = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gs = torch.empty( + (B, H, C // H, C // H), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6state_cuda.backward( + B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs + ) + gu = torch.sum(gu, 0).view(H, C // H) + gs = torch.sum(gs, 0).view(H, C // H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu, gs) + + def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s): + x = WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s) + return x, s + + elif os.environ["RWKV_TRAIN_TYPE"] == "states": + wkv6state_cuda = load( + name="wkv6state", + sources=[ + "finetune/lora/v6/cuda/wkv6state_op.cpp", + f"finetune/lora/v6/cuda/wkv6state_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={int(os.environ['RWKV_CTXLEN'])}", + ], + ) + + class WKV_6STATE(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u, s): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert s.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + assert s.is_contiguous() + ctx.save_for_backward(r, k, v, w, u, s) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6state_cuda.forward(B, T, C, H, r, k, v, w, u, s, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, w, u, s = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gw = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gs = torch.empty( + (B, H, C // H, C // H), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6state_cuda.backward( + B, T, C, H, r, k, v, w, u, s, gy, gr, gk, gv, gw, gu, gs + ) + gu = torch.sum(gu, 0).view(H, C // H) + gs = torch.sum(gs, 0).view(H, C // H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu, gs) + + def RUN_CUDA_RWKV6_STATE(B, T, C, H, r, k, v, w, u, s): + return WKV_6STATE.apply(B, T, C, H, r, k, v, w, u, s) + + else: + wkv6_cuda = load( + name="wkv6", + sources=[ + "finetune/lora/v6/cuda/wkv6_op.cpp", + f"finetune/lora/v6/cuda/wkv6_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={int(os.environ['RWKV_CTXLEN'])}", + ], + ) + + class WKV_6(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + ctx.save_for_backward(r, k, v, ew, u) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, ew, u = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gw = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv6_cuda.backward( + B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu + ) + gu = torch.sum(gu, 0).view(H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu) + + def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u): + return WKV_6.apply(B, T, C, H, r, k, v, w, u) + + else: + wkv5_cuda = load( + name="wkv5", + sources=[ + "finetune/lora/v6/cuda/wkv5_op.cpp", + f"finetune/lora/v6/cuda/wkv5_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], + ) + + class WKV_5(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + eew = (torch.exp(ew)).contiguous() + ctx.save_for_backward(r, k, v, eew, ew, u) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, eew, ew, u = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gw = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-1, 1) + wkv5_cuda.backward( + B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu + ) + gw = torch.sum(gw, 0).view(H, C // H) + gu = torch.sum(gu, 0).view(H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu) + + def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u): + return WKV_5.apply(B, T, C, H, r, k, v, w, u) ######################################################################################################## @@ -430,6 +791,8 @@ class RWKV_Tmix_x060(MyModule): ) TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g + if args.n_embd == 4096: + TIME_MIX_EXTRA_DIM = TIME_MIX_EXTRA_DIM * 2 self.time_maa_w1 = nn.Parameter( torch.zeros(args.n_embd, TIME_MIX_EXTRA_DIM * 5).uniform_(-1e-4, 1e-4) ) @@ -446,6 +809,8 @@ class RWKV_Tmix_x060(MyModule): self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att)) TIME_DECAY_EXTRA_DIM = 64 + if args.n_embd == 4096: + TIME_DECAY_EXTRA_DIM = TIME_DECAY_EXTRA_DIM * 2 self.time_decay_w1 = nn.Parameter( torch.zeros(args.n_embd, TIME_DECAY_EXTRA_DIM).uniform_(-1e-4, 1e-4) ) @@ -519,6 +884,133 @@ class RWKV_Tmix_x060(MyModule): ######################################################################################################## +class RWKV_Tmix_x060_state(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + + self.head_size = args.head_size_a + self.n_head = args.dim_att // self.head_size + assert args.dim_att % self.n_head == 0 + + with torch.no_grad(): + ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + + # fancy time_mix + self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_v = nn.Parameter( + 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + ) + self.time_maa_r = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + self.time_maa_g = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + + D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g + if args.n_embd == 4096: + D_MIX_LORA = D_MIX_LORA * 2 + self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5)) + self.time_maa_w2 = nn.Parameter( + torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01) + ) + + # fancy time_decay + decay_speed = torch.ones(args.dim_att) + for n in range(args.dim_att): + decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att)) + + D_DECAY_LORA = 64 + if args.n_embd == 4096: + D_DECAY_LORA = D_DECAY_LORA * 2 + self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA)) + self.time_decay_w2 = nn.Parameter( + torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01) + ) + + tmp = torch.zeros(args.dim_att) + for n in range(args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size)) + self.time_state = nn.Parameter( + torch.zeros(self.n_head, self.head_size, self.head_size) + ) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.key = nn.Linear(args.n_embd, args.dim_att, bias=False) + + self.value = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) + self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.ln_x = nn.GroupNorm( + self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2) + ) + + @MyFunction + def jit_func(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) - x + + xxx = x + xx * self.time_maa_x + xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + xw = x + xx * (self.time_maa_w + mw) + xk = x + xx * (self.time_maa_k + mk) + xv = x + xx * (self.time_maa_v + mv) + xr = x + xx * (self.time_maa_r + mr) + xg = x + xx * (self.time_maa_g + mg) + + r = self.receptance(xr) + k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 + w = self.time_decay + ww + + return r, k, v, g, w + + @MyFunction + def jit_func_2(self, x, g): + B, T, C = x.size() + x = x.view(B * T, C) + + x = self.ln_x(x).view(B, T, C) + x = self.output(x * g) + return x + + def forward(self, x): + B, T, C = x.size() + H = self.n_head + + r, k, v, g, w = self.jit_func(x) + x = RUN_CUDA_RWKV6_STATE( + B, T, C, H, r, k, v, w, u=self.time_faaaa, s=self.time_state + ) + + return self.jit_func_2(x, g) + + +######################################################################################################## + + class RWKV_ChannelMix(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -613,6 +1105,163 @@ class MishGLU(MyModule): return self.value(a * F.mish(b)) +######################################################################################################## + + +class RWKV_Tmix_x060_infctx(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + + self.head_size = args.head_size_a + self.n_head = args.dim_att // self.head_size + assert args.dim_att % self.n_head == 0 + + with torch.no_grad(): + ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + + # fancy time_mix + self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_v = nn.Parameter( + 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + ) + self.time_maa_r = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + self.time_maa_g = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + + D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g + if args.n_embd == 4096: + D_MIX_LORA = D_MIX_LORA * 2 + self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5)) + self.time_maa_w2 = nn.Parameter( + torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01) + ) + + # fancy time_decay + decay_speed = torch.ones(args.dim_att) + for n in range(args.dim_att): + decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att)) + + D_DECAY_LORA = 64 + if args.n_embd == 4096: + D_DECAY_LORA = D_DECAY_LORA * 2 + self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA)) + self.time_decay_w2 = nn.Parameter( + torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01) + ) + + tmp = torch.zeros(args.dim_att) + for n in range(args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size)) + # self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.key = nn.Linear(args.n_embd, args.dim_att, bias=False) + + self.value = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) + self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.ln_x = nn.GroupNorm( + self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2) + ) + + @MyFunction + def jit_func(self, x, shift_state): + B, T, C = x.size() + xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x + + xxx = x + xx * self.time_maa_x + xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + xw = x + xx * (self.time_maa_w + mw) + xk = x + xx * (self.time_maa_k + mk) + xv = x + xx * (self.time_maa_v + mv) + xr = x + xx * (self.time_maa_r + mr) + xg = x + xx * (self.time_maa_g + mg) + + r = self.receptance(xr) + k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 + w = self.time_decay + ww + + return r, k, v, g, w, x[:, -1] + + @MyFunction + def jit_func_2(self, x, g, timemixstate: TimeMixState): + B, T, C = x.size() + x = x.view(B * T, C) + + x = self.ln_x(x).view(B, T, C) + x = self.output(x * g) + return x, timemixstate + + def forward(self, x, last_state: TimeMixState): + B, T, C = x.size() + H = self.n_head + shift_state = last_state.shift_state + r, k, v, g, w, lx = self.jit_func(x, shift_state) + ###### + wkv_state = last_state.wkv_state.clone().contiguous() + x, wkv_state = RUN_CUDA_RWKV6_STATE( + B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state + ) + # wkv_state = last_state.wkv_state + return self.jit_func_2(x, g, TimeMixState(lx, wkv_state)) + + +class RWKV_CMix_x060_infctx(MyModule): + def __init__(self, args, layer_id): + super().__init__() + self.args = args + self.layer_id = layer_id + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0 + ddd = torch.ones(1, 1, args.n_embd) + for i in range(args.n_embd): + ddd[0, 0, i] = i / args.n_embd + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + + self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False) + self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False) + self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False) + + @MyFunction + def forward(self, x, last_state: ChannelMixState): + xx = torch.concat((last_state.shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x + xk = x + xx * self.time_maa_k + xr = x + xx * self.time_maa_r + + k = self.key(xk) + k = torch.relu(k) ** 2 + kv = self.value(k) + return torch.sigmoid(self.receptance(xr)) * kv, ChannelMixState(x[:, -1]) + + ######################################################################################################## # The RWKV Model with our blocks ######################################################################################################## @@ -641,7 +1290,12 @@ class Block(nn.Module): self.ffnPre = RWKV_ChannelMix(args, 0) else: if "x060" in os.environ["RWKV_MY_TESTING"]: - self.att = RWKV_Tmix_x060(args, layer_id) + if os.environ["RWKV_TRAIN_TYPE"] == "states": + self.att = RWKV_Tmix_x060_state(args, layer_id) + elif os.environ["RWKV_TRAIN_TYPE"] == "infctx": + self.att = RWKV_Tmix_x060_infctx(args, layer_id) + else: + self.att = RWKV_Tmix_x060(args, layer_id) else: self.att = RWKV_TimeMix_RWKV5(args, layer_id) @@ -649,7 +1303,10 @@ class Block(nn.Module): self.ffn = MishGLU(args, layer_id) else: if "x060" in os.environ["RWKV_MY_TESTING"]: - self.ffn = RWKV_CMix_x060(args, layer_id) + if os.environ["RWKV_TRAIN_TYPE"] == "infctx": + self.ffn = RWKV_CMix_x060_infctx(args, layer_id) + else: + self.ffn = RWKV_CMix_x060(args, layer_id) else: self.ffn = RWKV_ChannelMix(args, layer_id) @@ -666,53 +1323,125 @@ class Block(nn.Module): self.drop0 = nn.Dropout(p=args.dropout) self.drop1 = nn.Dropout(p=args.dropout) - def forward(self, x, x_emb=None): - args = self.args - B, T, C = x.size() - if self.layer_id == 0: - x = self.ln0(x) - if args.my_pos_emb > 0: - pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :] - x = x + pos_emb + if os.environ["RWKV_TRAIN_TYPE"] == "infctx": - if self.args.dropout == 0: - if self.layer_id == 0 and args.pre_ffn > 0: - x = x + self.ffnPre(self.ln1(x)) + def forward(self, x, last_state: BlockState, x_emb=None): + args = self.args + B, T, C = x.size() + if self.layer_id == 0: + x = self.ln0(x) + if args.my_pos_emb > 0: + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[ + :-1, : + ] + x = x + pos_emb + + if self.args.dropout == 0: + if self.layer_id == 0 and args.pre_ffn > 0: + x = x + self.ffnPre(self.ln1(x)) + else: + att_out, att_state = self.att( + self.ln1(x), last_state.time_mix_state + ) + x = x + att_out + ffn_out, fnn_state = self.ffn(self.ln2(x), last_state.channel_mix_state) + x = x + ffn_out else: - x = x + self.att(self.ln1(x)) - x = x + self.ffn(self.ln2(x)) - else: - if self.layer_id == 0 and args.pre_ffn > 0: - x = self.drop0(x + self.ffnPre(self.ln1(x))) + if self.layer_id == 0 and args.pre_ffn > 0: + x = self.drop0(x + self.ffnPre(self.ln1(x))) + else: + x = self.drop0(x + self.att(self.ln1(x))) + x = self.drop1(x + self.ffn(self.ln2(x))) + + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + xx = self.tiny_ln(x) + q = self.tiny_q(xx)[:, :T, :] + k = self.tiny_k(xx)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) + c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) + x = x + c @ self.tiny_v(x_emb) + return x, BlockState(att_state, fnn_state) + + else: + + def forward(self, x, x_emb=None): + args = self.args + B, T, C = x.size() + if self.layer_id == 0: + x = self.ln0(x) + if args.my_pos_emb > 0: + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[ + :-1, : + ] + x = x + pos_emb + + if self.args.dropout == 0: + if self.layer_id == 0 and args.pre_ffn > 0: + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) else: - x = self.drop0(x + self.att(self.ln1(x))) - x = self.drop1(x + self.ffn(self.ln2(x))) + if self.layer_id == 0 and args.pre_ffn > 0: + x = self.drop0(x + self.ffnPre(self.ln1(x))) + else: + x = self.drop0(x + self.att(self.ln1(x))) + x = self.drop1(x + self.ffn(self.ln2(x))) - if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: - xx = self.tiny_ln(x) - q = self.tiny_q(xx)[:, :T, :] - k = self.tiny_k(xx)[:, :T, :] - c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) - c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) - x = x + c @ self.tiny_v(x_emb) - return x + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: + xx = self.tiny_ln(x) + q = self.tiny_q(xx)[:, :T, :] + k = self.tiny_k(xx)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5)) + c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0) + x = x + c @ self.tiny_v(x_emb) + return x -class L2Wrap(torch.autograd.Function): - @staticmethod - def forward(ctx, loss, y): - ctx.save_for_backward(y) - return loss +if os.environ["RWKV_TRAIN_TYPE"] == "infctx": - @staticmethod - def backward(ctx, grad_output): - y = ctx.saved_tensors[0] - # to encourage the logits to be close to 0 - factor = 1e-4 / (y.shape[0] * y.shape[1]) - maxx, ids = torch.max(y, -1, keepdim=True) - gy = torch.zeros_like(y) - gy.scatter_(-1, ids, maxx * factor) - return (grad_output, gy) + class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y, token_amount): + ctx.save_for_backward(y) + ctx.token_amount = token_amount + return loss + + @staticmethod + def backward( + ctx, grad_output + ): # 这个函数会不会影响batch和grad_accu的一致性?感觉上会。梯度累积时,factor变大了。但是只有loss缩放,这里的正则化项反而没有缩放 + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + if ctx.token_amount == 0: + return (grad_output, None, None) + factor = 1e-4 / ctx.token_amount # 这一行类似crossentropy在token上平均。 + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + if os.environ.get("WN_FIX_L2WRAP"): # 实现batch等价性 + # maxx[maxx<3.]=0. #防止对已经较小的logits值下拉,只对大于阈值的往下拉 + gy.scatter_(-1, ids, maxx * factor * grad_output) + else: + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy, None) + +else: + + class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) class RWKV(pl.LightningModule): @@ -755,6 +1484,8 @@ class RWKV(pl.LightningModule): lr_2x = set() lr_3x = set() for n, p in self.named_parameters(): + if not p.requires_grad: + continue if (("_w1" in n) or ("_w2" in n)) and (args.layerwise_lr > 0): lr_1x.add(n) elif (("time_mix" in n) or ("time_maa" in n)) and (args.layerwise_lr > 0): @@ -896,100 +1627,237 @@ class RWKV(pl.LightningModule): return cfg.get("offload_optimizer") or cfg.get("offload_param") return False - def forward(self, idx): - args = self.args - B, T = idx.size() - assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." + if os.environ["RWKV_TRAIN_TYPE"] == "infctx": - x = self.emb(idx) - x_emb = x + def forward( + self, idx, last_shift_states: torch.Tensor, last_wkv_states: torch.Tensor + ): + args = self.args + B, T = idx.size() + assert T <= args.chunk_ctx, "Cannot forward, model ctx_len is exhausted." + C = args.n_embd + H = args.dim_att // args.head_size_a + assert C == H * args.head_size_a - if args.dropout > 0: - x = self.drop0(x) - if args.tiny_att_dim > 0: - for block in self.blocks: - if args.grad_cp == 1: - if args.lora: - x = torch_checkpoint(block, x, x_emb, use_reentrant=False) - else: - x = deepspeed.checkpointing.checkpoint(block, x, x_emb) + x = self.emb(idx) + x_emb = x + new_states = BlockStateList.empty( + args.n_layer, B, args.n_embd, H, x.device, x.dtype + ) + if args.dropout > 0: + x = self.drop0(x) + + for i, (block, block_state) in enumerate( + zip(self.blocks, BlockStateList(last_shift_states, last_wkv_states)) + ): + # x = x.to(block.device) + if args.grad_cp == 1 and i > 0: # and i < len(self.blocks)-1 + x, new_block_state = torch_checkpoint( + block, x, block_state, use_reentrant=False + ) else: - x = block(x, x_emb) - else: - for block in self.blocks: - if args.grad_cp == 1: - if args.lora: - x = torch_checkpoint(block, x, x_emb, use_reentrant=False) - else: - x = deepspeed.checkpointing.checkpoint(block, x) - else: - x = block(x) + x, new_block_state = block(x, block_state) + new_states[i] = new_block_state - x = self.ln_out(x) + x = self.ln_out(x) - if args.head_qk > 0: - q = self.head_q(x)[:, :T, :] - k = self.head_k(x)[:, :T, :] - c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) - c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + if args.head_qk > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - if "32" in os.environ["RWKV_FLOAT_MODE"]: - c = c @ F.one_hot(idx, num_classes=args.vocab_size) - elif os.environ["RWKV_FLOAT_MODE"] == "fp16": - c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() - elif os.environ["RWKV_FLOAT_MODE"] == "bf16": - c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() + if "32" in os.environ["RWKV_FLOAT_MODE"]: + c = c @ F.one_hot(idx, num_classes=args.vocab_size) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() - x = self.head(x) + c - else: - x = self.head(x) + x = self.head(x) + c + else: + x = self.head(x) - return x + return x, new_states.shift_states, new_states.wkv_states - def training_step(self, batch, batch_idx): - args = self.args - if args.my_qa_mask != 1: + def training_step(self, batch, batch_idx): + args = self.args + T_train = args.chunk_ctx idx, targets = batch - logits = self(idx) - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) - # if '0' in os.environ["RWKV_MY_TESTING"]: - # print('logits', logits) - # torch.set_printoptions(threshold=10000) - # print('idx', idx) - # exit(0) - else: - idx, targets, mask = batch - mask = mask.view(-1) - sum_mask = torch.sum(mask).item() - # if sum_mask == 0: - # return torch.tensor([0.0], requires_grad=True) + B, T = idx.shape + C = args.n_embd + H = args.dim_att // args.head_size_a + assert C == H * args.head_size_a + states = BlockStateList.create( + args.n_layer, B, C, H, idx.device, self.emb.weight.dtype + ) - logits = self(idx) - if sum_mask == mask.shape[0]: + def checkpointed_step( + idx, + targets, + prev_loss, + last_shift_states, + last_wkv_states, + prev_token_amount, + ): + logits, new_shift_states, new_wkv_states = self( + idx, last_shift_states, last_wkv_states + ) + current_token_amount = (targets != -100).sum() # 这样是不是更合适? + current_token_amount = idx.shape[1] + if current_token_amount == 0: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.reshape(-1), + reduction="sum", + ) + else: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.reshape(-1) + ) + loss = L2Wrap.apply(loss, logits, current_token_amount) + new_token_amount = prev_token_amount + current_token_amount + if new_token_amount > 0: + new_loss = prev_loss * ( + prev_token_amount / new_token_amount + ) + loss * (current_token_amount / new_token_amount) + else: + new_loss = prev_loss + + return new_loss, new_shift_states, new_wkv_states, new_token_amount + + total_loss = torch.tensor(0.0, dtype=self.emb.weight.dtype).requires_grad_() + token_amount = 0 + i = 0 + for i in range(math.ceil(T / T_train)): + # states.shift_states = states.shift_states.cuda() + # states.wkv_states = states.wkv_states.cuda() + total_loss, new_shift_states, new_wkv_states, token_amount = ( + torch_checkpoint( + checkpointed_step, + idx[:, i * T_train : (i + 1) * T_train], + targets[:, i * T_train : (i + 1) * T_train], + total_loss, + states.shift_states, + states.wkv_states, + token_amount, + use_reentrant=False, + ) + ) + # total_loss,new_shift_states, new_wkv_states,token_amount = checkpointed_step( + # idx[:, i * T_train:(i + 1) * T_train], + # targets[:, i * T_train:(i + 1) * T_train], + # total_loss, + # states.shift_states, + # states.wkv_states, + # token_amount + # ) + # new_shift_states = new_shift_states.cpu() + # new_wkv_states = new_wkv_states.cpu() + states = BlockStateList(new_shift_states, new_wkv_states) + + return total_loss + + else: + + def forward(self, idx): + args = self.args + B, T = idx.size() + assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." + + x = self.emb(idx) + x_emb = x + + if args.dropout > 0: + x = self.drop0(x) + if args.tiny_att_dim > 0: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora or args.state_tune or args.train_type == "state": + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x, x_emb) + else: + x = block(x, x_emb) + else: + for block in self.blocks: + if args.grad_cp == 1: + if args.lora or args.state_tune or args.train_type == "state": + x = torch_checkpoint(block, x, x_emb, use_reentrant=False) + else: + x = deepspeed.checkpointing.checkpoint(block, x) + else: + x = block(x) + + x = self.ln_out(x) + + if args.head_qk > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if "32" in os.environ["RWKV_FLOAT_MODE"]: + c = c @ F.one_hot(idx, num_classes=args.vocab_size) + elif os.environ["RWKV_FLOAT_MODE"] == "fp16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).half() + elif os.environ["RWKV_FLOAT_MODE"] == "bf16": + c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + return x + + def training_step(self, batch, batch_idx): + args = self.args + if args.my_qa_mask != 1: + idx, targets = batch + logits = self(idx) loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1) ) - # print('rank', self.global_rank, 'loss', loss.item()) + # if '0' in os.environ["RWKV_MY_TESTING"]: + # print('logits', logits) + # torch.set_printoptions(threshold=10000) + # print('idx', idx) + # exit(0) else: - loss = F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none" - ) - # loss_raw = loss - loss = torch.sum(loss * mask) / sum_mask + idx, targets, mask = batch + mask = mask.view(-1) + sum_mask = torch.sum(mask).item() + # if sum_mask == 0: + # return torch.tensor([0.0], requires_grad=True) - # torch.set_printoptions(threshold=10000) - # if True: #self.global_rank == 1: - # tmp = '' - # sss = 0 - # ccc = 0 - # for i in range(mask.shape[0]): - # if mask[i] > 0: - # tmp += str(idx.view(-1)[i].item()) + ',' - # sss += loss_raw.view(-1)[i].float().item() - # ccc += 1 - # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) + logits = self(idx) + if sum_mask == mask.shape[0]: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1) + ) + # print('rank', self.global_rank, 'loss', loss.item()) + else: + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction="none", + ) + # loss_raw = loss + loss = torch.sum(loss * mask) / sum_mask - return L2Wrap.apply(loss, logits) + # torch.set_printoptions(threshold=10000) + # if True: #self.global_rank == 1: + # tmp = '' + # sss = 0 + # ccc = 0 + # for i in range(mask.shape[0]): + # if mask[i] > 0: + # tmp += str(idx.view(-1)[i].item()) + ',' + # sss += loss_raw.view(-1)[i].float().item() + # ccc += 1 + # print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx) + + return L2Wrap.apply(loss, logits) def training_step_end(self, batch_parts): if pl.__version__[0] != "2": diff --git a/finetune/lora/v6/src/trainer.py b/finetune/lora/v6/src/trainer.py index e14e7fc..ac40cd7 100644 --- a/finetune/lora/v6/src/trainer.py +++ b/finetune/lora/v6/src/trainer.py @@ -4,6 +4,8 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from .model import LORA_CONFIG +import re +import numpy as np def my_save(args, trainer, dd, ff): @@ -21,10 +23,7 @@ def my_save(args, trainer, dd, ff): f" aws s3 mv {fff} s3://rwkv-world/{aa}-{fn} --quiet", shell=True ) else: - if "deepspeed_stage_3" in args.strategy: - trainer.save_checkpoint(ff, weights_only=True) - else: - torch.save(dd, ff) + torch.save(dd, ff) class train_callback(pl.Callback): @@ -181,6 +180,30 @@ class train_callback(pl.Callback): to_save_dict, f"{args.proj_dir}/rwkv-final.pth", ) + + if args.LISA and (batch_idx + 1) % args.lisa_k == 0: + pl_module.requires_grad_(False) + select_layers = np.random.choice( + range(args.n_layer), args.lisa_r, replace=False + ) + + for name, module in pl_module.named_modules(): + for pname, param in module.named_parameters(): + if ( + "emb" in pname + or "head" in pname + or ".ln" in pname + or "time" in pname + ): + param.requires_grad = True + elif "ln_out" in pname: + param.requires_grad = True + match = re.search(r"\d+", pname) + if match: + number = int(match.group()) + if number in select_layers: + param.requires_grad = True + break # if args.batch_save==batch_idx : # to_save_dict = pl_module.state_dict() # for name, state in to_save_dict.items(): @@ -229,12 +252,22 @@ class train_callback(pl.Callback): if "img" in name: to_save_dict[name] = state + if args.state_tune or args.train_type == "state": + lora_dict = {} + for name, state in to_save_dict.items(): + if "state" in name: + lora_dict[name] = state + to_save_dict = lora_dict + if args.lora: enable_time_finetune = "time" in LORA_CONFIG["parts"] enable_ln_finetune = "ln" in LORA_CONFIG["parts"] lora_dict = {} for name, state in to_save_dict.items(): - if "img" in name: + if len(args.load_model) == 0: + if "emb" in name or "head" in name or "ln" in name: + lora_dict[name] = state + if args.emb and "emb" in name: lora_dict[name] = state if ( ".lora_" in name diff --git a/finetune/lora/v6/train.py b/finetune/lora/v6/train.py index d41f37d..01a10ad 100644 --- a/finetune/lora/v6/train.py +++ b/finetune/lora/v6/train.py @@ -1,6 +1,7 @@ ######################################################################################################## # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## +import os import logging @@ -110,7 +111,7 @@ if __name__ == "__main__": parser.add_argument("--magic_prime", default=0, type=int) parser.add_argument("--my_qa_mask", default=0, type=int) parser.add_argument("--my_random_steps", default=0, type=int) - parser.add_argument("--my_testing", default="", type=str) + parser.add_argument("--my_testing", default="x052", type=str) parser.add_argument("--my_exit", default=99999999, type=int) parser.add_argument("--my_exit_tokens", default=0, type=int) @@ -123,6 +124,29 @@ if __name__ == "__main__": parser.add_argument("--lora_dropout", default=0.01, type=float) parser.add_argument("--lora_parts", default="att,ln,time", type=str) + # LISA + parser.add_argument("--LISA", action="store_true") + parser.add_argument("--lisa_r", default=2, type=int) + parser.add_argument("--lisa_k", default=100, type=int) + + # PISSA + parser.add_argument("--PISSA", action="store_true") + parser.add_argument("--svd_niter", default=4, type=int) + + # quant + parser.add_argument("--quant", default="none", type=str) + + # dataset + parser.add_argument("--dataload", default="get", type=str) + + # state tuning + parser.add_argument("--state_tune", action="store_true") + + parser.add_argument("--chunk_ctx", default=512, type=int) + # fla + parser.add_argument("--fla", action="store_true") + parser.add_argument("--train_type", default="none", type=str) + if pl.__version__[0] == "2": parser.add_argument("--accelerator", default="gpu", type=str) parser.add_argument("--strategy", default="auto", type=str) @@ -175,6 +199,14 @@ if __name__ == "__main__": os.environ["RWKV_MY_TESTING"] = args.my_testing os.environ["RWKV_CTXLEN"] = str(args.ctx_len) os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a) + ######state tuning + os.environ["RWKV_TRAIN_TYPE"] = "" + if args.train_type == "state": + os.environ["RWKV_TRAIN_TYPE"] = "states" + elif args.train_type == "infctx": + os.environ["RWKV_TRAIN_TYPE"] = "infctx" + + os.environ["WKV"] = "fla" if args.fla else "" if args.dim_att <= 0: args.dim_att = args.n_embd if args.dim_ffn <= 0: @@ -323,11 +355,68 @@ if __name__ == "__main__": enable_time_finetune = "time" in LORA_CONFIG["parts"] enable_ln_finetune = "ln" in LORA_CONFIG["parts"] model = RWKV(args) - - if args.lora: + freeze = False + if args.lora or args.LISA or args.train_type == "state": model.requires_grad_(False) - for name, module in model.named_modules(): + freeze = True + if args.state_tune or args.train_type == "state": + for name, module in model.named_modules(): + for pname, param in module.named_parameters(): + if "state" in pname: + param.requires_grad = True + break + + if args.LISA: + import re + + select_layers = np.random.choice( + range(args.n_layer), args.lisa_r, replace=False + ) + for name, module in model.named_modules(): + for pname, param in module.named_parameters(): + if ( + "emb" in pname + or "head" in pname + or ".ln" in pname + or "time" in pname + ): + param.requires_grad = True + match = re.search(r"\d+", pname) + if match: + number = int(match.group()) + if number in select_layers: + param.requires_grad = True + break + + elif args.lora: + + for name, module in model.named_modules(): + if len(args.load_model) == 0: + if any(n.startswith("emb.") for n, _ in module.named_parameters()): + for pname, param in module.named_parameters(): + if "emb.weight" == pname: + print(f" EMB additionally training module {pname}") + param.requires_grad = True + if any(n.startswith("head.") for n, _ in module.named_parameters()): + for pname, param in module.named_parameters(): + if "head.weight" == pname: + print(f" head additionally training module {pname}") + param.requires_grad = True + if "ln" in name: + print(f" LoRA additionally training module {name}") + for param in module.parameters(): + param.requires_grad = True + if any(n.startswith("emb.") for n, _ in module.named_parameters()): + for pname, param in module.named_parameters(): + if args.emb and "emb.weight" == pname: + print(f" EMB additionally training module {pname}") + param.requires_grad = True + if any(n.startswith("head.") for n, _ in module.named_parameters()): + for pname, param in module.named_parameters(): + if args.emb and "head.weight" == pname: + print(f" head additionally training module {pname}") + param.requires_grad = True if any(n.startswith("lora_") for n, _ in module.named_parameters()): print(f" LoRA additionally training module {name}") for pname, param in module.named_parameters(): @@ -376,11 +465,26 @@ if __name__ == "__main__": for k in model.state_dict(): if k not in load_keys: load_dict[k] = model.state_dict()[k] - model.load_state_dict(load_dict, strict=(not args.lora)) + model.load_state_dict(load_dict, strict=(not freeze)) if os.path.isfile(args.lora_load): model.load_state_dict( torch.load(args.lora_load, map_location="cpu"), strict=False ) + if args.PISSA: + init_dict = {} + rank_zero_info(f"########## Init PISSA... ##########") + for name, m in model.named_modules(): + if hasattr(m, "pissa_init") and callable(getattr(m, "pissa_init")): + m.pissa_init(args.svd_niter) + init_dict[f"{name}.init_lora_A"] = m.lora_A.data + init_dict[f"{name}.init_lora_B"] = m.lora_B.data + torch.save(init_dict, f"{args.proj_dir}/init_lora.pth") + + if args.quant != "none": + rank_zero_info(f"########## Quant... ##########") + for name, m in model.named_modules(): + if hasattr(m, "quant") and callable(getattr(m, "quant")): + m.quant(args.quant) if pl.__version__[0] == "2": trainer = Trainer( @@ -434,3 +538,73 @@ if __name__ == "__main__": ) trainer.fit(model, data_loader) + # if args.LISA: + # args.load_model=f'rwkv-0.pth' + # model = RWKV(args) + # model.requires_grad_(False) + + # select_layers = np.random.choice(range(args.n_layer), args.lisa_r, replace=False) + # for name, module in model.named_modules(): + # for pname, param in module.named_parameters(): + # if 'emb' in pname or 'head' in pname or '.ln' in pname or 'time' in pname : + # param.requires_grad = True + # match = re.search(r'\d+', pname) + # if match: + # number = int(match.group()) + # if number in select_layers: + # param.requires_grad = True + # break + # rank_zero_info(f"########## Loading {args.load_model}... ##########") + # try: + # load_dict = torch.load(args.load_model, map_location="cpu") + # load_keys = list(load_dict.keys()) + # for k in load_keys: + # if k.startswith('_forward_module.'): + # load_dict[k.replace('_forward_module.','')] = load_dict[k] + # del load_dict[k] + # except: + # rank_zero_info(f"Bad checkpoint {args.load_model}") + # if args.my_pile_stage >= 2: # try again using another checkpoint + # max_p = args.my_pile_prev_p + # if max_p == -1: + # args.load_model = f"{args.proj_dir}/rwkv-init.pth" + # else: + # args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" + # args.epoch_begin = max_p + 1 + # rank_zero_info(f"Trying {args.load_model}") + # load_dict = torch.load(args.load_model, map_location="cpu") + + # if args.load_partial == 1: + # load_keys = load_dict.keys() + # for k in model.state_dict(): + # if k not in load_keys: + # load_dict[k] = model.state_dict()[k] + # model.load_state_dict(load_dict, strict=(not args.lora)) + + # if pl.__version__[0]=='2': + # trainer = Trainer(accelerator=args.accelerator,strategy=args.strategy,devices=args.devices,num_nodes=args.num_nodes,precision=args.precision, + # logger=args.logger,callbacks=[train_callback(args)],max_epochs=args.max_epochs,check_val_every_n_epoch=args.check_val_every_n_epoch,num_sanity_val_steps=args.num_sanity_val_steps, + # log_every_n_steps=args.log_every_n_steps,enable_checkpointing=args.enable_checkpointing,accumulate_grad_batches=args.accumulate_grad_batches,gradient_clip_val=args.gradient_clip_val) + # else: + # trainer = Trainer.from_argparse_args( + # args, + # callbacks=[train_callback(args)], + # ) + + # if trainer.global_rank == 0: + # for n in model.state_dict(): + # shape = model.state_dict()[n].shape + # shape = [i for i in shape if i != 1] + # if len(shape) > 1: + # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + # else: + # print(f"{str(shape[0]).ljust(5)} {n}") + + # if "deepspeed" in args.strategy: + # trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + # trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + + # # must set shuffle=False, persistent_workers=False (because worker is in another thread) + # data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + + # trainer.fit(model, data_loader) diff --git a/finetune/requirements.txt b/finetune/requirements.txt index 09c962b..c9d3ecb 100644 --- a/finetune/requirements.txt +++ b/finetune/requirements.txt @@ -1,3 +1,7 @@ -torch==2.1.2 +torch==2.2.0 pytorch_lightning==1.9.5 deepspeed==0.12.6 +bitsandbytes==0.43.1 +einops==0.8.0 +triton==2.2.0 +transformers==4.41.1 \ No newline at end of file