upgrade cuda-beta

This commit is contained in:
josc146
2023-09-15 16:30:11 +08:00
parent c4042bbfd8
commit df969fcfc6
6 changed files with 601 additions and 90 deletions

View File

@@ -3,7 +3,7 @@
########################################################################################################
from typing import Optional
import types, gc, os, time, re
import types, gc, os, time, re, platform
import torch
from torch.nn import functional as F
@@ -91,6 +91,7 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
f"{current_path}/cuda/att_one.cu",
f"{current_path}/cuda/att_seq.cu",
f"{current_path}/cuda/ffn.cu",
f"{current_path}/cuda/att_one_v5.cu",
],
verbose=True,
extra_cuda_cflags=[
@@ -149,26 +150,40 @@ if os.environ.get("RWKV_CUDA_ON") == "1":
torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
return y.to(dtype=x.dtype)
else:
os.environ["RWKV_CUDA_ON"] = "0"
if os.environ.get("RWKV_CUDA_ON") == "1":
@MyStatic
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
output_dtype = a.dtype
if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
assert len(b.shape) == 2
if len(a.shape) == 1:
assert len(b.shape) == 2
c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
a = a.unsqueeze(0)
else:
c = torch.empty(
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
)
assert len(a.shape) == len(b.shape)
assert len(a.shape) == 2 or len(a.shape) == 3
# torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
if len(a.shape) == 2:
c = torch.empty(
(a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
)
else:
c = torch.empty(
(a.shape[0], a.shape[1], b.shape[-1]),
dtype=output_dtype,
device=a.device,
)
torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
return c
else:
return (a @ b).to(output_dtype)
else:
os.environ["RWKV_CUDA_ON"] = "0"
def gemm(a, b, output_dtype: Optional[torch.dtype] = None):
if output_dtype is None:
@@ -217,7 +232,7 @@ class RWKV(MyModule):
) # load model to CPU first
# it is supported to load a pure meta-tensor state dict (e.g. for quick testing)
for k, v in self.w.items():
if v.is_meta:
if isinstance(v, torch.Tensor) and v.is_meta:
# torch.zeros_like(v, device='cpu') doesn't produce an all-zero tensor
# if v is a meta tensor
self.w[k] = torch.zeros(v.shape, dtype=v.dtype, device="cpu")
@@ -247,9 +262,14 @@ class RWKV(MyModule):
args.n_embd = w["emb.weight"].shape[1]
args.n_layer = 0
keys = list(w.keys())
self.version = 4
for x in keys:
layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
args.n_layer = max(args.n_layer, layer_id + 1)
if "ln_x" in x:
self.version = 5
if self.version == 5 and "att.time_decay" in x:
args.n_head = w[x].shape[0]
####################### Compute strategy
@@ -352,6 +372,20 @@ class RWKV(MyModule):
del w["blocks.0.ln0.bias"]
print_need_newline = False
REAL_TIME_FIRST = False
for x in list(w.keys()):
if ".time_faaaa" in x:
REAL_TIME_FIRST = True
if REAL_TIME_FIRST:
w = {
k.replace(".time_faaaa", ".time_first")
if ".time_faaaa" in k
else k: v
for k, v in w.items()
}
self.w = w
keys = list(w.keys())
for x in keys:
w[x].requires_grad = False
@@ -382,8 +416,19 @@ class RWKV(MyModule):
w[x] = w[x].t()
if ".time_decay" in x: # need fp32 for this
w[x] = -torch.exp(w[x].float())
if self.version == 4:
w[x] = -torch.exp(w[x].float())
elif self.version == 5:
w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1)
elif ".time_first" in x: # need fp32 for this
if self.version == 4:
w[x] = w[x].float()
elif self.version == 5:
if REAL_TIME_FIRST:
w[x] = w[x].float().reshape(-1, 1, 1)
else:
w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1)
elif ".ln_x" in x: # need fp32 for group_norm
w[x] = w[x].float()
else:
if (len(w[x].shape) == 2) and ("emb" not in x):
@@ -931,6 +976,147 @@ class RWKV(MyModule):
########################################################################################################
@MyFunction
def att_one_v5(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
H = t_decay.shape[0]
S = x.shape[-1] // H
r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S)
k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1)
v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S)
a = gemm(k, v)
out = r @ (t_first * a + s)
s = a + t_decay * s
out = out.flatten()
out = F.group_norm(
out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b
).squeeze(0)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
return x + out, xx, s
@MyFunction
def att_seq_v5(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
H = t_decay.shape[0]
S = x.shape[-1] // H
T = x.shape[0]
w = t_decay.reshape(-1, 1)
u = t_first.reshape(-1, 1)
ws = w.pow(T).reshape(H, 1, 1)
ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
w = w.repeat(1, T).pow(ind)
wk = w.reshape(H, 1, T)
wb = wk.transpose(-2, -1).flip(1)
w = torch.cat([w[:, 1:], u], dim=1)
w = F.pad(w, (0, T))
w = torch.tile(w, [T])
w = w[:, :-T].reshape(-1, T, 2 * T - 1)
w = w[:, :, T - 1 :].reshape(H, T, T)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
k = (
gemm(kx, kw, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
out = ((r @ k) * w) @ v + (r @ s) * wb
s = ws * s + (k * wk) @ v
out = out.transpose(0, 1).contiguous().reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype)
out = gemm(out, ow)
return x + out, xx[-1, :], s
########################################################################################################
if os.environ["RWKV_CUDA_ON"] == "1":
@MyFunction
@@ -1140,7 +1326,7 @@ class RWKV(MyModule):
xx = torch.ops.rwkv.ffn_seq(
x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, buf, x_plus_out
)
return x_plus_out, xx[-1:]
return x_plus_out, xx[-1, :]
@MyFunction
def cuda_att_one_fp16(
@@ -1220,6 +1406,86 @@ class RWKV(MyModule):
)
return x_plus_out_t, xx, t1_t, t2_t, p_t
@MyFunction
def cuda_att_one_v5_fp16(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
t_decay,
t_first,
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
kx = torch.empty_like(x)
vx = torch.empty_like(x)
rx = torch.empty_like(x)
H = t_decay.shape[0]
S = x.shape[-1] // H
r = torch.empty((H * S,), dtype=torch.float32, device=x.device)
k = torch.empty((H * S,), dtype=torch.float32, device=x.device)
v = torch.empty((H * S,), dtype=torch.float32, device=x.device)
s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
x_plus_out = torch.empty_like(x)
xx = torch.ops.rwkv.att_one_v5(
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
kw,
kx,
vw,
vx,
rw,
rx,
ow,
t_first,
k,
t_decay,
v,
r,
s1,
x_plus_out,
s2,
)
return x_plus_out, xx, s2
@MyFunction
def cuda_ffn_one_fp16(
self,
@@ -1265,34 +1531,63 @@ class RWKV(MyModule):
args = self.args
if state == None:
state = [None] * args.n_layer * 5
for i in range(
args.n_layer
): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
state[i * 5 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 1] = torch.zeros(
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 2] = torch.zeros(
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 3] = (
torch.zeros(
if self.version == 4:
state = [None] * args.n_layer * 5
for i in range(
args.n_layer
): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
state[i * 5 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 1] = torch.zeros(
args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
- 1e30
)
state[i * 5 + 4] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 2] = torch.zeros(
args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 5 + 3] = (
torch.zeros(
args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
- 1e30
)
state[i * 5 + 4] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
elif self.version == 5:
state = [None] * args.n_layer * 3
for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
state[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 3 + 1] = torch.zeros(
(
args.n_head,
args.n_embd // args.n_head,
args.n_embd // args.n_head,
),
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 3 + 2] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
seq_mode = len(tokens) > 1
@@ -1317,9 +1612,13 @@ class RWKV(MyModule):
ATT = self.cuda_att_seq_i8
else:
ATT = self.cuda_att_seq_naive
if self.version == 5:
ATT = self.att_seq_v5
else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
if self.version == 5:
ATT = self.att_one_v5
if (
"cuda" in str(dev)
and os.environ["RWKV_CUDA_ON"] == "1"
@@ -1327,6 +1626,8 @@ class RWKV(MyModule):
):
ATT = self.cuda_att_one_fp16
FFN = self.cuda_ffn_one_fp16
if self.version == 5:
ATT = self.cuda_att_one_v5_fp16
x = x.to(dtype=atype, device=dev)
@@ -1355,46 +1656,82 @@ class RWKV(MyModule):
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
) = ATT(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
w[f"{bbb}ln1.weight"],
w[f"{bbb}ln1.bias"],
w[f"{att}time_mix_k"],
w[f"{att}time_mix_v"],
w[f"{att}time_mix_r"],
w[f"{att}time_decay"],
w[f"{att}time_first"],
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
)
if self.version == 4:
(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
) = ATT(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
w[f"{bbb}ln1.weight"],
w[f"{bbb}ln1.bias"],
w[f"{att}time_mix_k"],
w[f"{att}time_mix_v"],
w[f"{att}time_mix_r"],
w[f"{att}time_decay"],
w[f"{att}time_first"],
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
)
elif self.version == 5:
x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
x,
state[i * 3 + 0],
state[i * 3 + 1],
w[f"{bbb}ln1.weight"],
w[f"{bbb}ln1.bias"],
w[f"{att}ln_x.weight"],
w[f"{att}ln_x.bias"],
w[f"{att}time_mix_k"],
w[f"{att}time_mix_v"],
w[f"{att}time_mix_r"],
w[f"{att}time_decay"],
w[f"{att}time_first"],
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
)
if dd.stream:
del kw, vw, rw, ow
@@ -1417,9 +1754,13 @@ class RWKV(MyModule):
rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
x, state[i * 5 + 4] = FFN(
if self.version == 4:
offset = i * 5 + 4
elif self.version == 5:
offset = i * 3 + 2
x, state[offset] = FFN(
x,
state[i * 5 + 4],
state[offset],
w[f"{bbb}ln2.weight"],
w[f"{bbb}ln2.bias"],
w[f"{ffn}time_mix_k"],