upgrade rwkv to 0.8.16 (DirectML support; RWKV 5.2 no longer requires the custom CUDA kernel to be enabled)
backend-python/rwkv_pip/model.py (vendored): 130 changed lines

@@ -220,7 +220,7 @@ class RWKV(MyModule):
         else:
             prxxx = lambda *args, **kwargs: None
 
-        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
+        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
         if not re.match(STRATEGY_REGEX, strategy):
             raise ValueError(
                 "Invalid strategy. Please read https://pypi.org/project/rwkv/"
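For reference, the widened pattern now accepts a DirectML device token alongside cuda/cpu/mps. A minimal standalone check of the new regex (not part of the commit):

    import re

    STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"

    assert re.match(STRATEGY_REGEX, "dml fp16")                  # newly accepted DirectML device
    assert re.match(STRATEGY_REGEX, "cuda:0 fp16 -> cpu fp32")   # still valid
    assert not re.match(STRATEGY_REGEX, "dml fp64")              # fp64 is not a listed dtype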
@@ -372,6 +372,10 @@ class RWKV(MyModule):
                     strategy[n].atype = s[i][1][0]
                     strategy[n].wtype = s[i][1][1]
                     strategy[n].stream = False
+                    if strategy[n].device == "dml":
+                        import torch_directml
+
+                        strategy[n].device = torch_directml.device()
                     if i == stream_i and n >= (plan[i] - stream_count):
                         strategy[n].stream = True
                     break
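A layer whose strategy device is "dml" therefore ends up with a real torch device resolved by torch_directml. A hedged sketch of what that resolution amounts to, assuming the torch-directml package is installed:

    import torch
    import torch_directml  # provided by the torch-directml package

    dml_dev = torch_directml.device()   # default DirectML device (typically reported as "privateuseone:0")
    x = torch.randn(4, 4).to(dml_dev)   # tensors move to it like to any other torch device
    print(x.device)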
@@ -577,10 +581,7 @@ class RWKV(MyModule):
                 prxxx(f"Converted and saved. Now this will exit.")
                 exit(0)
 
-            if self.version == 5.2:
-                assert (
-                    os.environ["RWKV_CUDA_ON"] == "1"
-                ), "Please Enable Custom CUDA Kernel. Latest RWKV-5 requires os.environ['RWKV_CUDA_ON'] == '1' (will fix soon)"
+            if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == "1":
                 HEAD_SIZE = args.n_att // args.n_head
                 if LoadPreCompileLibrary("rwkv5") is True:
                     rwkv5 = torch.ops.rwkv5
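With the assertion removed, a v5.2 model can be constructed without the custom kernel; the rwkv5 op is only loaded when RWKV_CUDA_ON is "1". A hedged usage sketch (the import path and model file are placeholders, since the vendored copy lives under backend-python/rwkv_pip/):

    import os

    os.environ["RWKV_CUDA_ON"] = "0"   # the custom CUDA kernel is no longer required for v5.2

    from rwkv.model import RWKV        # placeholder import path for the vendored module

    model = RWKV(model="models/RWKV-5-World-1B5.pth", strategy="cpu fp32")  # placeholder model path
    # Setting RWKV_CUDA_ON=1 together with a CUDA strategy would instead route v5.2
    # attention through the precompiled op loaded via LoadPreCompileLibrary("rwkv5").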
@@ -1363,6 +1364,7 @@ class RWKV(MyModule):
 
     ########################################################################################################
 
+    @MyFunction
     def att_seq_v5_2(
         self,
         x,
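For context, MyFunction in the rwkv package is torch.jit.script_method when RWKV_JIT_ON is set and a no-op otherwise, roughly as sketched below (paraphrased, not from this diff). The pure-PyTorch att_seq_v5_2 can now be scripted, while the wrapper that calls the custom op stays undecorated (see the NOTE in a later hunk):

    import os
    import torch

    if os.environ.get("RWKV_JIT_ON") == "1":
        MyModule = torch.jit.ScriptModule
        MyFunction = torch.jit.script_method
    else:
        MyModule = torch.nn.Module

        def __nop(ob):
            return ob

        MyFunction = __nop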
@@ -1408,29 +1410,29 @@ class RWKV(MyModule):
         gx = xx * g_mix + sx * (1 - g_mix)
 
         H = t_decay.shape[0]
-        N = x.shape[-1] // H
+        S = x.shape[-1] // H
         T = x.shape[0]
 
-        r = gemm(rx, rw, output_dtype=torch.float32)
-        k = gemm(kx, kw, output_dtype=torch.float32)
-        v = gemm(vx, vw, output_dtype=torch.float32)
+        r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
+        k = (
+            gemm(kx, kw, output_dtype=torch.float32)
+            .view(T, H, S)
+            .transpose(0, 1)
+            .transpose(-2, -1)
+        )
+        v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1)
         g = F.silu(gemm(gx, gw))
 
-        out, s = self.RUN_RWKV_5(
-            1,
-            T,
-            self.args.n_att,
-            H,
-            s.transpose(-1, -2).contiguous(),
-            r,
-            k,
-            v,
-            w=t_decay,
-            u=t_first,
-        )
-        s = s.transpose(-1, -2)
+        out = torch.empty((T, H, S), dtype=r.dtype, device=r.device)
+        for t in range(T):
+            rt = r[:, t : t + 1, :]
+            kt = k[:, :, t : t + 1]
+            vt = v[:, t : t + 1, :]
+            at = gemm(kt, vt)
+            out[t] = (rt @ (t_first * at + s)).squeeze(1)
+            s = at + t_decay * s
 
-        out = out.reshape(T, H * N)
+        out = out.reshape(T, H * S)
         out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
         out = out.to(dtype=x.dtype) * g
         out = gemm(out, ow)
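The replacement loop is the RWKV-5.2 linear-attention recurrence written per timestep: a_t = k_t v_t^T (an outer product per head), out_t = r_t (u * a_t + s), then s is decayed and updated as s = a_t + w * s. A toy, self-contained sketch of the same recurrence with random tensors (shapes mirror the diff; the exact layout of t_decay/t_first in the real model is an assumption here):

    import torch

    H, S, T = 2, 4, 3                  # heads, head size, sequence length
    r = torch.randn(H, T, S)           # receptance, one row per timestep
    k = torch.randn(H, S, T)           # keys, transposed so k[:, :, t] is a column
    v = torch.randn(H, T, S)           # values
    w = torch.rand(H, S, 1)            # per-channel decay (t_decay), assumed shape
    u = torch.rand(H, S, 1)            # "first token" bonus (t_first), assumed shape
    s = torch.zeros(H, S, S)           # running state

    out = torch.empty(T, H, S)
    for t in range(T):
        rt = r[:, t : t + 1, :]        # (H, 1, S)
        kt = k[:, :, t : t + 1]        # (H, S, 1)
        vt = v[:, t : t + 1, :]        # (H, 1, S)
        at = kt @ vt                   # outer product, (H, S, S)
        out[t] = (rt @ (u * at + s)).squeeze(1)
        s = at + w * s                 # decay the state, add the new outer product
    out = out.reshape(T, H * S)        # flatten heads before group norm, as in the diff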
@@ -1543,6 +1545,81 @@ class RWKV(MyModule):
             out = self.mm8_seq(r * y, ow, omx, orx, omy, ory)
         return x + out, xx[-1, :], aa, bb, pp
 
+    # NOTE: decorate with @MyFunction causes JIT error
+    def cuda_att_seq_v5_2(
+        self,
+        x,
+        sx,
+        s,
+        ln_w,
+        ln_b,
+        lx_w,
+        lx_b,
+        k_mix,
+        v_mix,
+        r_mix,
+        g_mix,
+        t_decay,
+        t_first,
+        kw,
+        vw,
+        rw,
+        gw,
+        ow,
+        kmx,
+        krx,
+        kmy,
+        kry,
+        vmx,
+        vrx,
+        vmy,
+        vry,
+        rmx,
+        rrx,
+        rmy,
+        rry,
+        omx,
+        orx,
+        omy,
+        ory,
+    ):
+        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
+        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
+        kx = xx * k_mix + sx * (1 - k_mix)
+        vx = xx * v_mix + sx * (1 - v_mix)
+        rx = xx * r_mix + sx * (1 - r_mix)
+        gx = xx * g_mix + sx * (1 - g_mix)
+
+        H = t_decay.shape[0]
+        N = x.shape[-1] // H
+        T = x.shape[0]
+
+        r = gemm(rx, rw, output_dtype=torch.float32)
+        k = gemm(kx, kw, output_dtype=torch.float32)
+        v = gemm(vx, vw, output_dtype=torch.float32)
+        g = F.silu(gemm(gx, gw))
+
+        out, s = self.RUN_RWKV_5(
+            1,
+            T,
+            self.args.n_att,
+            H,
+            s.transpose(-1, -2).contiguous(),
+            r,
+            k,
+            v,
+            w=t_decay,
+            u=t_first,
+        )
+        s = s.transpose(-1, -2)
+
+        out = out.reshape(T, H * N)
+        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b)
+        out = out.to(dtype=x.dtype) * g
+        out = gemm(out, ow)
+
+        return x + out, xx[-1, :], s
+
     ########################################################################################################
 
     def forward(self, tokens, state, full_output=False):
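Note the layout difference between the two v5.2 paths: the CUDA wrapper above keeps r, k, v flat at (T, n_att) and presumably leaves the per-head reshaping to the rwkv5 kernel, while the pure-PyTorch att_seq_v5_2 reshapes to (H, T, S) itself. A small sanity check that the two layouts carry the same data (toy tensors, not calling the kernel):

    import torch

    T, H, S = 3, 2, 4
    n_att = H * S
    flat = torch.arange(T * n_att, dtype=torch.float32).reshape(T, n_att)  # what the CUDA path passes on
    per_head = flat.view(T, H, S).transpose(0, 1)                          # what the PyTorch path builds
    # Same data, different layout: per_head[h, t] is the h-th head slice of flat[t].
    assert torch.equal(per_head[1, 2], flat[2, S:2 * S])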
@@ -1622,7 +1699,10 @@ class RWKV(MyModule):
                 atype = dd.atype
                 wtype = dd.wtype
                 if seq_mode:
-                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
+                    cuda_applicable = os.environ[
+                        "RWKV_CUDA_ON"
+                    ] == "1" and "cuda" in str(dev)
+                    if cuda_applicable:
                         ATT = (
                             self.cuda_att_seq
                             if wtype != torch.uint8
@@ -1636,6 +1716,8 @@ class RWKV(MyModule):
                         ATT = self.att_seq_v5_1
                     elif self.version == 5.2:
                         ATT = self.att_seq_v5_2
+                        if cuda_applicable:
+                            ATT = self.cuda_att_seq_v5_2
                     FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
                 else:
                     ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
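Taken together, the two dispatch hunks mean a layer only uses the custom rwkv5 kernel when RWKV_CUDA_ON is "1" and that layer actually sits on a CUDA device; everything else (cpu, mps, dml, or kernel disabled) falls back to att_seq_v5_2. A hedged, self-contained helper that mirrors that condition (hypothetical, not from the diff):

    import os

    def picks_cuda_kernel(dev: str) -> bool:
        """Hypothetical helper: True when the v5.2 sequence path would use the custom rwkv5 kernel."""
        return os.environ.get("RWKV_CUDA_ON") == "1" and "cuda" in dev

    os.environ["RWKV_CUDA_ON"] = "1"
    assert picks_cuda_kernel("cuda:0") is True
    assert picks_cuda_kernel("cpu") is False      # cpu/mps/dml layers stay on the PyTorch path

    os.environ["RWKV_CUDA_ON"] = "0"
    assert picks_cuda_kernel("cuda:0") is False   # kernel disabled, v5.2 still runs via att_seq_v5_2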