########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import functools
import os, math, gc, importlib
import torch

# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy

if importlib.util.find_spec("deepspeed"):
    import deepspeed
    from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

# lora-config
LORA_CONFIG = {
    "r": 0,
    "alpha": 0,
    "dropout": 0,
    "parts": {"att", "ln", "time"},
}

try:
    print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"])
except:
    os.environ["RWKV_MY_TESTING"] = ""


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method


########################################################################################################
# CUDA Kernel
########################################################################################################

from torch.utils.cpp_extension import load

HEAD_SIZE = int(os.environ["RWKV_HEAD_SIZE_A"])
wkv5_cuda = load(
    name="wkv5",
    sources=[
        "finetune/lora/v5/cuda/wkv5_op.cpp",
        f"finetune/lora/v5/cuda/wkv5_cuda.cu",
    ],
    verbose=True,
    extra_cuda_cflags=[
        "-res-usage",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
        f"-D_N_={HEAD_SIZE}",
    ],
)


class WKV_5(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        with torch.no_grad():
            assert r.dtype == torch.bfloat16
            assert k.dtype == torch.bfloat16
            assert v.dtype == torch.bfloat16
            assert w.dtype == torch.bfloat16
            assert u.dtype == torch.bfloat16
            assert HEAD_SIZE == C // H
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            assert r.is_contiguous()
            assert k.is_contiguous()
            assert v.is_contiguous()
            assert w.is_contiguous()
            assert u.is_contiguous()
            ew = (-torch.exp(w.float())).contiguous()
            eew = (torch.exp(ew)).contiguous()
            ctx.save_for_backward(r, k, v, eew, ew, u)
            y = torch.empty(
                (B, T, C),
                device=r.device,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
            return y

    @staticmethod
    def backward(ctx, gy):
        with torch.no_grad():
            assert gy.dtype == torch.bfloat16
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            assert gy.is_contiguous()
            r, k, v, eew, ew, u = ctx.saved_tensors
            gr = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gk = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gv = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gw = torch.empty(
                (B, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            gu = torch.empty(
                (B, C),
                device=gy.device,
                requires_grad=False,
                dtype=torch.bfloat16,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-1, 1)
            wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
            gw = torch.sum(gw, 0).view(H, C // H)
            gu = torch.sum(gu, 0).view(H, C // H)
            return (None, None, None, None, gr, gk, gv, gw, gu)


def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
    return WKV_5.apply(B, T, C, H, r, k, v, w, u)

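
# Illustrative sketch (hypothetical helper, not called anywhere in this file): how the
# fused kernel above is typically invoked. Shapes and dtypes follow the asserts in
# WKV_5.forward: r, k, v are contiguous (B, T, C) bfloat16 CUDA tensors, while w (raw
# per-channel decay) and u (the "time_faaaa" bonus) are (H, C // H) bfloat16.
def _example_run_wkv5():
    B, T, H = 1, 8, 2
    C = H * HEAD_SIZE
    kw = dict(device="cuda", dtype=torch.bfloat16)
    r = torch.randn(B, T, C, **kw)
    k = torch.randn(B, T, C, **kw)
    v = torch.randn(B, T, C, **kw)
    w = torch.randn(H, HEAD_SIZE, **kw)  # raw decay; WKV_5.forward converts it to exp(-exp(w))
    u = torch.randn(H, HEAD_SIZE, **kw)  # per-head bonus applied to the current token
    return RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=w, u=u)  # -> (B, T, C) bfloat16
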

#################################################################


class LoraLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        assert bias == False, "Biased LoraLinear not supported"

        r, alpha, dropout = (
            LORA_CONFIG["r"],
            LORA_CONFIG["alpha"],
            LORA_CONFIG["dropout"],
        )
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.lora_dropout = nn.Dropout(dropout)
        self.scaling = alpha / r

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return F.linear(x, self.weight) + self.scaling * F.linear(
            F.linear(self.lora_dropout(x), self.lora_A), self.lora_B
        )


@functools.wraps(LoraLinear)
def make_linear_att(*args, **kwargs):
    if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)


@functools.wraps(LoraLinear)
def make_linear_ffn(*args, **kwargs):
    if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0:
        return LoraLinear(*args, **kwargs)
    else:
        return nn.Linear(*args, **kwargs)

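
# Illustrative sketch (hypothetical helper, not part of the training code): LoraLinear
# keeps the base weight W and adds a low-rank update, so its forward computes
# y = x @ W.T + (alpha / r) * x @ A.T @ B.T.  Ignoring dropout (identity at eval time),
# the update can be folded into a single dense matrix for export:
def _example_merge_lora(layer: "LoraLinear") -> torch.Tensor:
    with torch.no_grad():
        delta = layer.lora_B @ layer.lora_A  # (out_features, in_features)
        return layer.weight + layer.scaling * delta
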

########################################################################################################


class RWKV_TimeMix_RWKV5(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        assert HEAD_SIZE == self.head_size  # change HEAD_SIZE to match args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0
        self.head_size_divisor = args.head_size_divisor

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(
                torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (
                    0.7 + 1.3 * ratio_0_to_1
                )
            self.time_decay = nn.Parameter(
                decay_speed.reshape(self.n_head, self.head_size)
            )
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)

        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)

    @MyFunction
    def jit_func(self, x):
        B, T, C = x.size()

        xx = self.time_shift(
            x
        )  # Mix x with the previous timestep to produce xk, xv, xr
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        return r, k, v, g

    @MyFunction
    def jit_func_2(self, x, g):
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g = self.jit_func(x)

        x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)

        return self.jit_func_2(x, g)


########################################################################################################


class RWKV_ChannelMix(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

        self.key = make_linear_ffn(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = make_linear_ffn(args.n_embd, args.n_embd, bias=False)
        self.value = make_linear_ffn(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        k = self.key(xk)
        k = torch.relu(k) ** 2
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv


class MishGLU(MyModule):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)

            x = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                x[0, 0, i] = i / args.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.aa = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.bb = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
            self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x)
        xa = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xb = x * self.time_mix_r + xx * (1 - self.time_mix_r)
        a = self.aa(xa)
        b = self.bb(xb)
        return self.value(a * F.mish(b))

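
# Illustrative sketch (hypothetical helper, for reading only): the token-shift mixing used
# by both RWKV_TimeMix_RWKV5 and RWKV_ChannelMix above. nn.ZeroPad2d((0, 0, 1, -1)) shifts
# the sequence right by one step (position 0 sees zeros, position t sees x[t - 1]), and each
# channel interpolates between the current and previous token with a learned mix parameter.
def _example_token_shift(x: torch.Tensor, time_mix: torch.Tensor) -> torch.Tensor:
    xx = nn.ZeroPad2d((0, 0, 1, -1))(x)  # (B, T, C): row t holds x[t - 1], row 0 is zeros
    return x * time_mix + xx * (1 - time_mix)
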

########################################################################################################
# The RWKV Model with our blocks
########################################################################################################


class Block(nn.Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)
            if args.my_pos_emb > 0:
                self.pos_emb_x = nn.Parameter(
                    torch.zeros((1, args.my_pos_emb, args.n_embd))
                )
                self.pos_emb_y = nn.Parameter(
                    torch.zeros((args.my_pos_emb, 1, args.n_embd))
                )

        if self.layer_id == 0 and self.args.pre_ffn > 0:
            self.ffnPre = RWKV_ChannelMix(args, 0)
        else:
            self.att = RWKV_TimeMix_RWKV5(args, layer_id)

        if "g" in os.environ["RWKV_MY_TESTING"]:
            self.ffn = MishGLU(args, layer_id)
        else:
            self.ffn = RWKV_ChannelMix(args, layer_id)

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            self.tiny_ln = nn.LayerNorm(args.n_embd)
            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
            self.register_buffer(
                "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)
            if args.my_pos_emb > 0:
                pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :]
                x = x + pos_emb

        if self.args.dropout == 0:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = x + self.ffnPre(self.ln1(x))
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = self.drop0(x + self.ffnPre(self.ln1(x)))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
            xx = self.tiny_ln(x)
            q = self.tiny_q(xx)[:, :T, :]
            k = self.tiny_k(xx)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
            x = x + c @ self.tiny_v(x_emb)
        return x


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

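
# Illustrative sketch (hypothetical helper, not used in training): L2Wrap passes the loss
# through unchanged but, in backward, injects factor * max(logit) at each position's argmax,
# i.e. the gradient of an implicit 0.5 * factor * max(logit)**2 penalty that nudges the
# largest logit toward 0, matching the comment above.
def _example_l2wrap_gradient():
    y = torch.tensor([[[1.0, 4.0, -2.0]]], requires_grad=True)  # (B=1, T=1, vocab=3)
    loss = y.mean() * 0.0  # dummy loss so only L2Wrap's extra gradient remains
    L2Wrap.apply(loss, y).backward()
    # y.grad is zero except at the argmax, where it equals 4.0 * (1e-4 / (B * T))
    return y.grad
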

class RWKV(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            args.dim_ffn = args.n_embd * 4
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer(
                "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))
            )
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def configure_optimizers(self):
        args = self.args

        lr_decay = set()
        lr_1x = set()
        lr_2x = set()
        lr_3x = set()
        for n, p in self.named_parameters():
            if ("time_mix" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_decay" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_3x.add(n)
                else:
                    lr_2x.add(n)
            elif ("time_faaaa" in n) and (args.layerwise_lr > 0):
                if args.my_pile_stage == 2:
                    lr_2x.add(n)
                else:
                    lr_1x.add(n)
            elif ("time_first" in n) and (args.layerwise_lr > 0):
                lr_3x.add(n)
            elif (len(p.squeeze().shape) >= 2) and (args.weight_decay > 0):
                lr_decay.add(n)
            else:
                lr_1x.add(n)

        lr_decay = sorted(list(lr_decay))
        lr_1x = sorted(list(lr_1x))
        lr_2x = sorted(list(lr_2x))
        lr_3x = sorted(list(lr_3x))
        # print('decay', lr_decay)
        # print('1x', lr_1x)
        # print('2x', lr_2x)
        # print('3x', lr_3x)
        param_dict = {n: p for n, p in self.named_parameters()}

        if args.layerwise_lr > 0:
            if args.my_pile_stage == 2:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    },  # test: 2e-3 / args.lr_init},
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 5.0,
                    },  # test: 3e-3 / args.lr_init},
                ]
            else:
                optim_groups = [
                    {
                        "params": [param_dict[n] for n in lr_1x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 1.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_2x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 2.0,
                    },
                    {
                        "params": [param_dict[n] for n in lr_3x],
                        "weight_decay": 0.0,
                        "my_lr_scale": 3.0,
                    },
                ]
        else:
            optim_groups = [
                {
                    "params": [param_dict[n] for n in lr_1x],
                    "weight_decay": 0.0,
                    "my_lr_scale": 1.0,
                }
            ]

        if args.weight_decay > 0:
            optim_groups += [
                {
                    "params": [param_dict[n] for n in lr_decay],
                    "weight_decay": args.weight_decay,
                    "my_lr_scale": 1.0,
                }
            ]
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=True,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=True,
                amsgrad=False,
            )
        else:
            if self.deepspeed_offload:
                return DeepSpeedCPUAdam(
                    optim_groups,
                    lr=self.args.lr_init,
                    betas=self.args.betas,
                    eps=self.args.adam_eps,
                    bias_correction=True,
                    adamw_mode=False,
                    weight_decay=0,
                    amsgrad=False,
                )
            return FusedAdam(
                optim_groups,
                lr=self.args.lr_init,
                betas=self.args.betas,
                eps=self.args.adam_eps,
                bias_correction=True,
                adam_w_mode=False,
                weight_decay=0,
                amsgrad=False,
            )
        # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            cfg = strategy.config["zero_optimization"]
            return cfg.get("offload_optimizer") or cfg.get("offload_param")
        return False

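    # Illustrative note (an assumption about the surrounding trainer code, which is not
    # defined in this file): each group built in configure_optimizers carries a custom
    # "my_lr_scale" key, and a training callback is expected to rescale the learning rate
    # per group, roughly:
    #
    #     for group in optimizer.param_groups:
    #         group["lr"] = current_lr * group.get("my_lr_scale", 1.0)
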
    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    if args.lora:
                        x = torch_checkpoint(block, x, x_emb, use_reentrant=False)
                    else:
                        x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

    def training_step(self, batch, batch_idx):
        args = self.args
        if args.my_qa_mask != 1:
            idx, targets = batch
            logits = self(idx)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            # if '0' in os.environ["RWKV_MY_TESTING"]:
            #     print('logits', logits)
            #     torch.set_printoptions(threshold=10000)
            #     print('idx', idx)
            #     exit(0)
        else:
            idx, targets, mask = batch
            mask = mask.view(-1)
            sum_mask = torch.sum(mask).item()
            # if sum_mask == 0:
            #     return torch.tensor([0.0], requires_grad=True)

            logits = self(idx)
            if sum_mask == mask.shape[0]:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1)
                )
                # print('rank', self.global_rank, 'loss', loss.item())
            else:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
                )
                # loss_raw = loss
                loss = torch.sum(loss * mask) / sum_mask

                # torch.set_printoptions(threshold=10000)
                # if True: #self.global_rank == 1:
                #     tmp = ''
                #     sss = 0
                #     ccc = 0
                #     for i in range(mask.shape[0]):
                #         if mask[i] > 0:
                #             tmp += str(idx.view(-1)[i].item()) + ','
                #             sss += loss_raw.view(-1)[i].float().item()
                #             ccc += 1
                #     print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

        return L2Wrap.apply(loss, logits)

    def training_step_end(self, batch_parts):
        if pl.__version__[0] != "2":
            all = self.all_gather(batch_parts)
            if self.trainer.is_global_zero:
                self.trainer.my_loss_all = all

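    # Illustrative sketch of the masked branch in training_step above: when my_qa_mask is
    # enabled and some positions are masked out, the per-token cross-entropy is averaged
    # only over positions with mask == 1, i.e.
    #
    #     ce = F.cross_entropy(logits.view(-1, V), targets.view(-1), reduction="none")
    #     loss = (ce * mask).sum() / mask.sum()
    #
    # so prompt/padding tokens contribute nothing to the loss or its gradient.
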
    def generate_init_weight(self):
        print(
            f"""
############################################################################
#
# Init model weight (slow for large models)...
#
############################################################################
"""
        )
        m = {}
        for n in self.state_dict():
            p = self.state_dict()[n]
            shape = p.shape
            gain = 1.0
            scale = 1.0
            if (
                "ln_" in n
                or ".ln" in n
                or "time_" in n
                or "_mask" in n
                or "pos_emb" in n
                or ".mask." in n
            ):
                if "ln_x.weight" in n:
                    layer_scale = (1 + int(n.split(".")[1])) / self.args.n_layer
                    m[n] = (p * 0.0) + (layer_scale**0.7)
                else:
                    m[n] = p
            else:
                if n == "emb.weight":
                    scale = -1 * self.args.lr_init
                else:
                    if shape[0] > shape[1]:
                        gain = math.sqrt(shape[0] / shape[1])

                    zero = [
                        ".att.output.",
                        ".ffn.value.",
                        ".ffn.receptance.",
                        ".ffnPre.value.",
                        ".ffnPre.receptance.",
                        "head_q.",
                        ".oo.",
                        ".rr.",
                    ]

                    for kk in zero:
                        if kk in n:
                            scale = 0
                    if n == "head.weight":
                        scale = 0.5
                    if "head_k." in n:
                        scale = 0.1
                    if "head_q." in n:
                        scale = 0

                print(
                    f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}"
                )

                if self.args.accelerator.upper() == "GPU":
                    m[n] = torch.empty((shape[0], shape[1]), device="cuda")
                else:
                    m[n] = torch.empty((shape[0], shape[1]))

                if scale == 0:
                    nn.init.zeros_(m[n])
                elif scale < 0:
                    nn.init.uniform_(m[n], a=scale, b=-scale)
                else:
                    nn.init.orthogonal_(m[n], gain=gain * scale)

            m[n] = m[n].cpu()
            if os.environ["RWKV_FLOAT_MODE"] == "fp16":
                m[n] = m[n].half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                m[n] = m[n].bfloat16()
            # if n == "emb.weight":
            #     print(m[n])

        gc.collect()
        torch.cuda.empty_cache()
        return m
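

# Illustrative sketch (hypothetical helper, not used by the training scripts): the minimal
# set of attributes the classes above read from `args` when constructing the model.
# Values are placeholders; n_embd/dim_att/dim_ffn must be multiples of 32 and of the head
# size selected via RWKV_HEAD_SIZE_A, and n_layer must be >= 2 for the layer-ratio init.
def _example_build_model():
    from argparse import Namespace

    args = Namespace(
        n_layer=2,
        n_embd=2 * HEAD_SIZE,
        dim_att=2 * HEAD_SIZE,
        dim_ffn=8 * HEAD_SIZE,
        vocab_size=65536,
        ctx_len=512,
        head_size_a=HEAD_SIZE,
        head_size_divisor=8,
        my_pos_emb=0,
        pre_ffn=0,
        tiny_att_dim=-1,
        tiny_att_layer=-1,
        head_qk=0,
        dropout=0,
        grad_cp=0,
        lora=False,
    )
    return RWKV(args)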