diff --git a/python/fastembed-server.py b/python/fastembed-server.py index fa3f7c82b..dd4a7a9c8 100644 --- a/python/fastembed-server.py +++ b/python/fastembed-server.py @@ -2,18 +2,20 @@ from fastembed import TextEmbedding from fastapi import FastAPI from pydantic import BaseModel -model = TextEmbedding("snowflake/snowflake-arctic-embed-xs") +models = {} app = FastAPI() class EmbeddingRequest(BaseModel): model: str - prompt: str + input: str -@app.post("/api/embeddings") +@app.post("/v1/embeddings") def embeddings(request: EmbeddingRequest): - embeddings = next(model.embed(request.prompt)).tolist() - return {"embedding": embeddings} + model = models.get(request.model) or TextEmbedding(request.model) + models[request.model] = model + embeddings = next(model.embed(request.input)).tolist() + return {"data": [{"embedding": embeddings}]} if __name__ == "__main__": import uvicorn