upgrade to rwkv 0.8.26 (state instruct align support)
This commit is contained in:
		
							parent
							
								
									70236df3d1
								
							
						
					
					
						commit
						38b33a7030
					
				@ -1,7 +1,7 @@
 | 
			
		||||
torch
 | 
			
		||||
torchvision
 | 
			
		||||
torchaudio
 | 
			
		||||
rwkv==0.8.25
 | 
			
		||||
rwkv==0.8.26
 | 
			
		||||
langchain==0.0.322
 | 
			
		||||
fastapi==0.109.1
 | 
			
		||||
uvicorn==0.23.2
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
torch
 | 
			
		||||
torchvision
 | 
			
		||||
torchaudio
 | 
			
		||||
rwkv==0.8.25
 | 
			
		||||
rwkv==0.8.26
 | 
			
		||||
langchain==0.0.322
 | 
			
		||||
fastapi==0.109.1
 | 
			
		||||
uvicorn==0.23.2
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								backend-python/rwkv_pip/model.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										20
									
								
								backend-python/rwkv_pip/model.py
									
									
									
									
										vendored
									
									
								
							@ -488,14 +488,19 @@ class RWKV(MyModule):
 | 
			
		||||
            print_need_newline = False
 | 
			
		||||
 | 
			
		||||
            REAL_TIME_FIRST = False
 | 
			
		||||
            args.time_state = False
 | 
			
		||||
            for x in list(w.keys()):
 | 
			
		||||
                if ".time_faaaa" in x:
 | 
			
		||||
                    REAL_TIME_FIRST = True
 | 
			
		||||
                if ".time_state" in x:
 | 
			
		||||
                    args.time_state = True
 | 
			
		||||
            if REAL_TIME_FIRST:
 | 
			
		||||
                w = {
 | 
			
		||||
                    (
 | 
			
		||||
                        k.replace(".time_faaaa", ".time_first")
 | 
			
		||||
                        if ".time_faaaa" in k
 | 
			
		||||
                    else k: v
 | 
			
		||||
                        else k
 | 
			
		||||
                    ): v
 | 
			
		||||
                    for k, v in w.items()
 | 
			
		||||
                }
 | 
			
		||||
                self.w = w
 | 
			
		||||
@ -631,7 +636,9 @@ class RWKV(MyModule):
 | 
			
		||||
                        torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
                shape = [i for i in w[x].shape if i != 1]
 | 
			
		||||
                if len(shape) > 1:
 | 
			
		||||
                if len(shape) > 2:
 | 
			
		||||
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
 | 
			
		||||
                elif len(shape) > 1:
 | 
			
		||||
                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}      "
 | 
			
		||||
                else:
 | 
			
		||||
                    shape = f" {str(shape[0]).rjust(5)}            "
 | 
			
		||||
@ -2108,6 +2115,15 @@ class RWKV(MyModule):
 | 
			
		||||
                        state[i * 3 + 0] = torch.zeros(
 | 
			
		||||
                            args.n_embd, dtype=atype, requires_grad=False, device=dev
 | 
			
		||||
                        ).contiguous()
 | 
			
		||||
                        if args.time_state:
 | 
			
		||||
                            state[i * 3 + 1] = (
 | 
			
		||||
                                w[f"blocks.{i}.att.time_state"]
 | 
			
		||||
                                .transpose(1, 2)
 | 
			
		||||
                                .to(dtype=torch.float, device=dev)
 | 
			
		||||
                                .requires_grad_(False)
 | 
			
		||||
                                .contiguous()
 | 
			
		||||
                            )
 | 
			
		||||
                        else:
 | 
			
		||||
                            state[i * 3 + 1] = torch.zeros(
 | 
			
		||||
                                (
 | 
			
		||||
                                    args.n_head,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user