Skip to content

Commit

Permalink
Merge pull request #4 from getzep/local_embedding
Browse files Browse the repository at this point in the history
add env switch for embedding. default to off
  • Loading branch information
danielchalef authored Jun 16, 2023
2 parents 26574a0 + 0e4a98d commit 9609485
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ LANGUAGE_MODEL := en_core_web_sm
all: download-language-model format lint test

run:
poetry run python main.py
ENABLE_EMBEDDINGS=true poetry run python main.py

run-dev:
poetry run uvicorn main:app --reload --log-level debug --port 8080
ENABLE_EMBEDDINGS=true poetry run uvicorn main:app --reload --log-level debug --port 8080

docker-build:
DOCKER_BUILDKIT=1 docker build -t $(CONTAINER_NAME) .
Expand All @@ -36,7 +36,7 @@ lint:
poetry run ruff .

test:
poetry run pytest app/tests
ENABLE_EMBEDDINGS=true poetry run pytest app/tests

loadtest:
poetry run locust -f load_test.py --headless --run-time 60 -u 50 --host http://0.0.0.0:8080
ENABLE_EMBEDDINGS=true poetry run locust -f load_test.py --headless --run-time 60 -u 50 --host http://0.0.0.0:8080
14 changes: 11 additions & 3 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# Licensed under the MIT License.
# Heavy modified by Zep

import os

import spacy
import srsly # type: ignore
from fastapi import Body, FastAPI, status
from fastapi import Body, FastAPI, HTTPException, status
from fastapi.responses import ORJSONResponse
from starlette.responses import PlainTextResponse, RedirectResponse

Expand All @@ -13,6 +15,8 @@
from app.entity_extractor import SpacyExtractor
from app.entity_models import Request, Response

ENABLE_EMBEDDINGS = os.getenv("ENABLE_EMBEDDINGS", "false").lower() == "true"

app = FastAPI(
title="zep-nlp-server",
version="0.2",
Expand All @@ -24,7 +28,8 @@
nlp = spacy.load("en_core_web_sm")
extractor = SpacyExtractor(nlp)

embedder = Embedder()
if ENABLE_EMBEDDINGS:
embedder = Embedder()


@app.get("/healthz", response_model=str, status_code=status.HTTP_200_OK)
Expand All @@ -47,4 +52,7 @@ async def extract_entities(body: Request = Body(..., example=example_request)):
@app.post("/embeddings", response_class=ORJSONResponse)
async def embed_collection(collection: Collection):
"""Embed a Collection of Documents."""
return ORJSONResponse(embedder.embed(collection))
if ENABLE_EMBEDDINGS:
return ORJSONResponse(embedder.embed(collection))
else:
raise HTTPException(status_code=400, detail="Embeddings not enabled.")

0 comments on commit 9609485

Please sign in to comment.