fix_tokens
This commit is contained in:
parent
7bc8da2e29
commit
cf16e54463
@ -10,6 +10,7 @@ from routes import state_cache
|
|||||||
|
|
||||||
END_OF_TEXT = 0
|
END_OF_TEXT = 0
|
||||||
END_OF_LINE = 187
|
END_OF_LINE = 187
|
||||||
|
END_OF_LINE_DOUBLE = 535  # token id the tokenizer emits for a trailing '\n\n'
|
||||||
|
|
||||||
|
|
||||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||||
@ -77,6 +78,12 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
except HTTPException:
|
except HTTPException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# The model has only ever seen '\n\n' as the token pair [187, 187], but the tokenizer emits the single token [535] when '\n\n' is at the end of the text
|
||||||
|
def fix_tokens(tokens):
|
||||||
|
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
|
||||||
|
tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
|
||||||
|
return tokens
|
||||||
|
|
||||||
def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
|
def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
|
||||||
tokens = [int(x) for x in _tokens]
|
tokens = [int(x) for x in _tokens]
|
||||||
self.model_tokens += tokens
|
self.model_tokens += tokens
|
||||||
@ -112,7 +119,7 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
logits = copy.deepcopy(cache["logits"])
|
logits = copy.deepcopy(cache["logits"])
|
||||||
|
|
||||||
if delta_prompt != "":
|
if delta_prompt != "":
|
||||||
logits = self.run_rnn(self.pipeline.encode(delta_prompt))
|
logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
|
||||||
try:
|
try:
|
||||||
state_cache.add_state(
|
state_cache.add_state(
|
||||||
state_cache.AddStateBody(
|
state_cache.AddStateBody(
|
||||||
|
Loading…
Reference in New Issue
Block a user