|
|
from typing import Generic, List, Optional, TypeVar |
|
|
from functools import partial |
|
|
from pydantic import BaseModel |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from fastapi import FastAPI |
|
|
import numpy |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import ORJSONResponse |
|
|
|
|
|
# Single shared model instance, constructed once at import time so that
# request handlers do not have to reload it on every call.
MODEL = SentenceTransformer("all-mpnet-base-v2")
|
|
|
|
|
def cache(func):
    """Memoize a batch function on a per-sentence basis.

    ``func`` must accept a list of sentences and return one result per
    sentence, in the same order. The returned wrapper forwards only the
    sentences it has not seen before to ``func`` and answers everything
    else from an in-memory dict.

    Args:
        func: Callable taking a ``List[str]`` and returning an equally long
            sequence of results.

    Returns:
        A wrapper with the same call shape as ``func``. It returns a list
        with one result per input sentence, in input order (duplicates in
        the input yield the same cached object more than once).

    Note:
        Entries are never evicted, so the cache grows with every distinct
        sentence seen during the life of the process. Cached objects are
        handed back as-is (not copied); callers should treat them as
        read-only.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    from functools import wraps

    inner_cache = {}

    @wraps(func)
    def inner(sentences: List[str]):
        if not sentences:
            return []

        # dict.fromkeys() removes duplicates while preserving first-seen
        # order, so a sentence repeated within one call is computed once.
        missing = [s for s in dict.fromkeys(sentences) if s not in inner_cache]

        if missing:
            # Pass a copy so that ``func`` mutating its argument cannot
            # desynchronize the sentence/result pairing below.
            computed = func(list(missing))
            inner_cache.update(zip(missing, computed))

        return [inner_cache[s] for s in sentences]

    return inner
|
|
|
|
|
@cache
def _encode(sentences: List[str]):
    """Embed ``sentences`` with the shared ``MODEL`` and return plain lists.

    The model is called with ``normalize_embeddings=True``, a batch size of
    2 and the progress bar enabled. Every resulting vector is rounded to
    three decimal places and converted with ``.tolist()`` before being
    returned.

    Because of the ``@cache`` decorator, this function only receives the
    sentences that the cache wrapper has not already stored.

    Args:
        sentences: Sentences to embed.

    Returns:
        A list with one rounded embedding (itself a list) per sentence, in
        the same order as ``sentences``.
    """
    vectors = MODEL.encode(
        sentences,
        normalize_embeddings=True,
        batch_size=2,
        show_progress_bar=True,
    )
    return [numpy.around(vector, 3).tolist() for vector in vectors]
|
|
|
|
|
# Request body accepted by the /embed endpoint.
class EmbedReq(BaseModel):
    # Sentences to embed; the endpoint returns one embedding per entry.
    sentences: List[str]
|
|
|
|
|
# ASGI application object; the route and middleware below are registered on it.
app = FastAPI()
|
|
|
|
|
# POST /embed: takes an EmbedReq body and responds with the embeddings of its
# sentences, serialized through ORJSONResponse.
@app.post("/embed", response_class=ORJSONResponse)
def embed(embed: EmbedReq):
    # _encode is wrapped by the per-sentence cache, so sentences seen in
    # earlier requests are answered without calling the model again.
    vectors = _encode(embed.sentences)
    return vectors
|
|
|
|
|
# CORS is configured with wildcards for origins, methods and headers.
# NOTE(review): confirm that allowing every origin is intended for the
# environments this service is deployed to.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)