RWKVType
This commit is contained in:
		
							parent
							
								
									d12a173f39
								
							
						
					
					
						commit
						8764c37b03
					
				@ -1,4 +1,5 @@
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from enum import Enum, auto
 | 
			
		||||
import os
 | 
			
		||||
import pathlib
 | 
			
		||||
import copy
 | 
			
		||||
@ -18,6 +19,12 @@ END_OF_LINE_DOUBLE = 535
 | 
			
		||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RWKVType(Enum):
 | 
			
		||||
    Raven = auto()
 | 
			
		||||
    World = auto()
 | 
			
		||||
    Music = auto()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AbstractRWKV(ABC):
 | 
			
		||||
    def __init__(self, model: str, strategy: str, tokens_path: str):
 | 
			
		||||
        from rwkv.model import RWKV as Model  # dynamic import to make RWKV_CUDA_ON work
 | 
			
		||||
@ -29,6 +36,7 @@ class AbstractRWKV(ABC):
 | 
			
		||||
        self.pipeline = PIPELINE(self.model, tokens_path)
 | 
			
		||||
        self.model_state = None
 | 
			
		||||
        self.model_tokens = []
 | 
			
		||||
        self.rwkv_type: RWKVType = None
 | 
			
		||||
 | 
			
		||||
        self.max_tokens_per_generation = 500
 | 
			
		||||
        self.temperature = 1
 | 
			
		||||
@ -343,10 +351,12 @@ class TextRWKV(AbstractRWKV):
 | 
			
		||||
 | 
			
		||||
        self.interface = ":"
 | 
			
		||||
        if "world" in self.name.lower():
 | 
			
		||||
            self.rwkv_type = RWKVType.World
 | 
			
		||||
            self.user = "Question"
 | 
			
		||||
            self.bot = "Answer"
 | 
			
		||||
            self.END_OF_LINE = 11
 | 
			
		||||
        else:
 | 
			
		||||
            self.rwkv_type = RWKVType.Raven
 | 
			
		||||
            self.user = "Bob"
 | 
			
		||||
            self.bot = "Alice"
 | 
			
		||||
            self.END_OF_LINE = 187
 | 
			
		||||
@ -387,7 +397,7 @@ class TextRWKV(AbstractRWKV):
 | 
			
		||||
 | 
			
		||||
    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
 | 
			
		||||
    def fix_tokens(self, tokens) -> List[int]:
 | 
			
		||||
        if "world" in self.name.lower():
 | 
			
		||||
        if self.rwkv_type == RWKVType.World:
 | 
			
		||||
            return tokens
 | 
			
		||||
        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
 | 
			
		||||
            tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
 | 
			
		||||
@ -427,7 +437,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 | 
			
		||||
{bot} likes to tell {user} a lot about herself and her opinions. \
 | 
			
		||||
{bot} usually gives {user} kind, helpful and informative advices.\n
 | 
			
		||||
"""
 | 
			
		||||
            if self.user == "Bob"
 | 
			
		||||
            if self.rwkv_type == RWKVType.Raven
 | 
			
		||||
            else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
 | 
			
		||||
            + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
 | 
			
		||||
        )
 | 
			
		||||
@ -454,6 +464,8 @@ class MusicRWKV(AbstractRWKV):
 | 
			
		||||
        self.top_p = 0.8
 | 
			
		||||
        self.top_k = 8
 | 
			
		||||
 | 
			
		||||
        self.rwkv_type = RWKVType.Music
 | 
			
		||||
 | 
			
		||||
    def adjust_occurrence(self, occurrence: Dict, token: int):
 | 
			
		||||
        for n in occurrence:
 | 
			
		||||
            occurrence[n] *= 0.997  #### decay repetition penalty
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user