add /file-to-text api

This commit is contained in:
josc146 2023-10-25 17:14:33 +08:00
parent 4a18696686
commit df9e1f408e
2 changed files with 77 additions and 2 deletions

View File

@ -2,7 +2,7 @@ import time
start_time = time.time()
import setuptools # avoid warnings
import setuptools # avoid warnings
import os
import sys
import argparse
@ -20,7 +20,7 @@ from utils.rwkv import *
from utils.torch import *
from utils.ngrok import *
from utils.log import log_middleware
from routes import completion, config, state_cache, midi, misc
from routes import completion, config, state_cache, midi, misc, file_process
import global_var
@ -43,6 +43,7 @@ app.add_middleware(
app.include_router(completion.router)
app.include_router(config.router)
app.include_router(midi.router)
app.include_router(file_process.router)
app.include_router(misc.router)
app.include_router(state_cache.router)

View File

@ -0,0 +1,74 @@
import os
from fastapi import (
APIRouter,
HTTPException,
status,
Depends,
File,
UploadFile,
)
from pydantic import BaseModel
from typing import Iterator
router = APIRouter()
class FileToTextParams(BaseModel):
file_name: str
file_encoding: str = "utf-8"
@router.post("/file-to-text", tags=["File Process"])
async def file_to_text(
params: FileToTextParams = Depends(), file_data: UploadFile = File(...)
):
from langchain.schema import Document
from langchain.document_loaders.blob_loaders import Blob
# from langchain
def parse_text(blob: Blob) -> Iterator[Document]:
yield Document(page_content=blob.as_string(), metadata={"source": blob.source})
# from langchain
def parse_pdf(blob: Blob) -> Iterator[Document]:
import fitz
with blob.as_bytes_io() as stream:
doc = fitz.Document(stream=stream)
yield from [
Document(
page_content=page.get_text(),
metadata=dict(
{
"source": blob.source,
"file_path": blob.source,
"page": page.number,
"total_pages": len(doc),
},
**{
k: doc.metadata[k]
for k in doc.metadata
if type(doc.metadata[k]) in [str, int]
},
),
)
for page in doc
]
file_parsers = {".txt": parse_text, ".pdf": parse_pdf}
file_ext = os.path.splitext(params.file_name)[-1]
if file_ext not in file_parsers:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "file type not supported")
pages: Iterator[Document] = file_parsers[file_ext](
Blob.from_data(
await file_data.read(),
encoding=params.file_encoding,
path=params.file_name,
)
)
return {"pages": pages}