move state cache to memory (todo: state cache db)
This commit is contained in:
		
							parent
							
								
									b63370928d
								
							
						
					
					
						commit
						b41a2e7039
					
				@ -73,6 +73,7 @@ body.json:
 | 
			
		||||
- [x] CUDA operator int8 acceleration
 | 
			
		||||
- [ ] macOS support
 | 
			
		||||
- [ ] Linux support
 | 
			
		||||
- [ ] Local State Cache DB
 | 
			
		||||
 | 
			
		||||
## Related Repositories:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -73,6 +73,7 @@ body.json:
 | 
			
		||||
- [x] CUDA算子int8提速
 | 
			
		||||
- [ ] macOS支持
 | 
			
		||||
- [ ] linux支持
 | 
			
		||||
- [ ] 本地状态缓存数据库
 | 
			
		||||
 | 
			
		||||
## 相关仓库:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Response, status
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
import gc
 | 
			
		||||
import copy
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
@ -41,10 +42,14 @@ def add_state(body: AddStateBody):
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    id = trie.insert(body.prompt)
 | 
			
		||||
    device = body.state[0].device
 | 
			
		||||
    dtrie[id] = {
 | 
			
		||||
        "tokens": copy.deepcopy(body.tokens),
 | 
			
		||||
        "state": copy.deepcopy(body.state),
 | 
			
		||||
        "state": [tensor.cpu() for tensor in body.state]
 | 
			
		||||
        if device != torch.device("cpu")
 | 
			
		||||
        else copy.deepcopy(body.state),
 | 
			
		||||
        "logits": copy.deepcopy(body.logits),
 | 
			
		||||
        "device": device,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
@ -77,11 +82,15 @@ def longest_prefix_state(body: LongestPrefixStateBody):
 | 
			
		||||
        pass
 | 
			
		||||
    if id != -1:
 | 
			
		||||
        v = dtrie[id]
 | 
			
		||||
        device = v["device"]
 | 
			
		||||
        return {
 | 
			
		||||
            "prompt": trie[id],
 | 
			
		||||
            "tokens": v["tokens"],
 | 
			
		||||
            "state": v["state"],
 | 
			
		||||
            "state": [tensor.to(device) for tensor in v["state"]]
 | 
			
		||||
            if device != torch.device("cpu")
 | 
			
		||||
            else v["state"],
 | 
			
		||||
            "logits": v["logits"],
 | 
			
		||||
            "device": device,
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
        return {"prompt": "", "tokens": [], "state": None, "logits": None}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user