upgrade rwkv to 0.8.16 (DirectML support; RWKV 5.2 no longer requires the custom CUDA kernel to be enabled)

josc146
2023-10-25 17:56:18 +08:00
parent 2acdaa96b2
commit 0331bf47f7
10 changed files with 106 additions and 523 deletions


@@ -220,7 +220,7 @@ class RWKV(MyModule):
else:
prxxx = lambda *args, **kwargs: None
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
if not re.match(STRATEGY_REGEX, strategy):
raise ValueError(
"Invalid strategy. Please read https://pypi.org/project/rwkv/"
@@ -372,6 +372,10 @@ class RWKV(MyModule):
strategy[n].atype = s[i][1][0]
strategy[n].wtype = s[i][1][1]
strategy[n].stream = False
if strategy[n].device == "dml":
import torch_directml
strategy[n].device = torch_directml.device()
if i == stream_i and n >= (plan[i] - stream_count):
strategy[n].stream = True
break
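
A minimal sketch of what the new "dml" branch does, assuming the optional torch-directml package is installed: the "dml" placeholder from the strategy string is swapped for an actual DirectML torch device, so weights and activations can be placed on it like on any other device.

import torch
import torch_directml

dev = torch_directml.device()        # DirectML device (typically reported as privateuseone:0)
x = torch.randn(4, 4).to(dev)        # tensors created on CPU can be moved to the DML device
print(x.device)
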
@@ -577,10 +581,7 @@ class RWKV(MyModule):
prxxx(f"Converted and saved. Now this will exit.")
exit(0)
if self.version == 5.2:
assert (
os.environ["RWKV_CUDA_ON"] == "1"
), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == "1":
HEAD_SIZE = args.n_att // args.n_head
if LoadPreCompileLibrary("rwkv5") is True:
rwkv5 = torch.ops.rwkv5
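
With the hard assert gone, an RWKV-5.2 model can be loaded with the custom CUDA kernel disabled; the rwkv5 op is only set up when RWKV_CUDA_ON is "1". A minimal usage sketch against the plain rwkv pip package (the model path and token ids below are placeholders):

import os
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = "0"     # previously this tripped the v5.2 assert

from rwkv.model import RWKV

model = RWKV(model="path/to/RWKV-5-World", strategy="cpu fp32")  # placeholder path
logits, state = model.forward([187, 510, 1563], None)            # arbitrary token ids
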
@@ -1363,6 +1364,7 @@ class RWKV(MyModule):
########################################################################################################
@MyFunction
def att_seq_v5_2(
self,
x,
@@ -1408,29 +1410,29 @@ class RWKV(MyModule):
gx = xx * g_mix + sx * (1 - g_mix)
H = t_decay.shape[0]
N = x.shape[-1] // H
S = x.shape[-1] // H
T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32)
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
k = (
gemm(kx, kw, output_dtype=torch.float32)
.view(T, H, S)
.transpose(0, 1)
.transpose(-2, -1)
)
v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
g = F.silu(gemm(gx, gw))
out, s = self.RUN_RWKV_5(
1,
T,
self.args.n_att,
H,
s.transpose(-1, -2).contiguous(),
r,
k,
v,
w=t_decay,
u=t_first,
)
s = s.transpose(-1, -2)
out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
for t in range(T):
rt = r[:, t : t + 1, :]
kt = k[:, :, t : t + 1]
vt = v[:, t : t + 1, :]
at = gemm(kt, vt)
out[t] = (rt @ (t_first * at + s)).squeeze(1)
s = at + t_decay * s
out = out.reshape(T, H * N)
out = out.reshape(T, H * S)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
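
The replacement above is the pure-PyTorch fallback for v5.2: instead of calling the fused kernel, it runs the RWKV-5 state recurrence one timestep at a time. A standalone sketch of that recurrence with random tensors (broadcast shapes for t_first/t_decay are simplified here, and plain @ stands in for the model's gemm helper):

import torch

H, S, T = 4, 8, 6                     # heads, head size, sequence length
r = torch.randn(H, T, S)              # receptance
k = torch.randn(H, S, T)              # key, transposed so kt is a column
v = torch.randn(H, T, S)              # value
t_first = torch.randn(H, 1, 1)        # "u", bonus applied to the current token
t_decay = torch.rand(H, 1, 1)         # "w", per-head decay in (0, 1)
s = torch.zeros(H, S, S)              # recurrent state

out = torch.empty(T, H, S)
for t in range(T):
    rt = r[:, t : t + 1, :]           # (H, 1, S)
    kt = k[:, :, t : t + 1]           # (H, S, 1)
    vt = v[:, t : t + 1, :]           # (H, 1, S)
    at = kt @ vt                      # outer product, (H, S, S)
    out[t] = (rt @ (t_first * at + s)).squeeze(1)
    s = at + t_decay * s              # decay old state, add current contribution
out = out.reshape(T, H * S)
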
@@ -1543,6 +1545,81 @@ class RWKV(MyModule):
out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
return x + out, xx[-1, :], aa, bb, pp
# NOTE: decorate with @MyFunction causes JIT error
def cuda_att_seq_v5_2(
self,
x,
sx,
s,
ln_w,
ln_b,
lx_w,
lx_b,
k_mix,
v_mix,
r_mix,
g_mix,
t_decay,
t_first,
kw,
vw,
rw,
gw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
):
xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
gx = xx * g_mix + sx * (1 - g_mix)
H = t_decay.shape[0]
N = x.shape[-1] // H
T = x.shape[0]
r = gemm(rx, rw, output_dtype=torch.float32)
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)
g = F.silu(gemm(gx, gw))
out, s = self.RUN_RWKV_5(
1,
T,
self.args.n_att,
H,
s.transpose(-1, -2).contiguous(),
r,
k,
v,
w=t_decay,
u=t_first,
)
s = s.transpose(-1, -2)
out = out.reshape(T, H * N)
out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
out = out.to(dtype=x.dtype) * g
out = gemm(out, ow)
return x + out, xx[-1, :], s
########################################################################################################
def forward(self, tokens, state, full_output=False):
@@ -1622,7 +1699,10 @@ class RWKV(MyModule):
atype = dd.atype
wtype = dd.wtype
if seq_mode:
if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
cuda_applicable = os.environ[
"RWKV_CUDA_ON"
] == "1" and "cuda" in str(dev)
if cuda_applicable:
ATT = (
self.cuda_att_seq
if wtype != torch.uint8
@@ -1636,6 +1716,8 @@ class RWKV(MyModule):
ATT = self.att_seq_v5_1
elif self.version == 5.2:
ATT = self.att_seq_v5_2
if cuda_applicable:
ATT = self.cuda_att_seq_v5_2
FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
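
Condensed, the new dispatch works as follows (illustrative helper, not code from the commit): the fused kernel is chosen only when RWKV_CUDA_ON is "1" and the layer sits on a CUDA device; on cpu, mps or a DirectML device, v5.2 now falls back to the pure-PyTorch att_seq_v5_2 path instead of asserting.

import os

def pick_att_impl(version: float, dev: str) -> str:
    cuda_applicable = os.environ.get("RWKV_CUDA_ON") == "1" and "cuda" in str(dev)
    if version == 5.2:
        return "cuda_att_seq_v5_2" if cuda_applicable else "att_seq_v5_2"
    return "att_seq / att_seq_i8 (pre-5.2 paths, elided)"

print(pick_att_impl(5.2, "cuda:0"))           # cuda_att_seq_v5_2 when RWKV_CUDA_ON=1
print(pick_att_impl(5.2, "privateuseone:0"))  # att_seq_v5_2 (DirectML falls back to pure PyTorch)
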