Skip to content

Commit

Permalink
Merge pull request #8 from getzep/improve/optional-model-load
Browse files Browse the repository at this point in the history
improve/optional embedding; fix ComputeDevice selection; add mps
  • Loading branch information
danielchalef authored Aug 18, 2023
2 parents ef671ac + 38fa597 commit 19c6863
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 29 deletions.
40 changes: 19 additions & 21 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Dict

from fastapi import Depends, FastAPI, status
from fastapi.responses import ORJSONResponse
from starlette.responses import PlainTextResponse, RedirectResponse
Expand All @@ -16,55 +18,46 @@


@app.on_event("startup")
def startup_event() -> None:
    """Initialize the embedder and extractor when the app starts.

    Calling the getters here (rather than lazily on first request) surfaces
    configuration/model-load failures immediately at startup.
    """
    get_embedder()
    get_extractor()


@app.get("/healthz", response_model=str, status_code=status.HTTP_200_OK)
def health() -> PlainTextResponse:
    """Liveness probe: always answers 200 with a minimal plain-text body."""
    return PlainTextResponse(".")


@app.get("/config")
def config() -> Dict[str, Any]:
    """Get the current configuration.

    Returns the full settings object as a plain dict.
    NOTE(review): this exposes every setting value unauthenticated — confirm
    the settings contain no secrets before shipping.
    """
    return settings.dict()


@app.get("/", include_in_schema=False)
def docs_redirect() -> RedirectResponse:
    """Redirect the root path to the interactive API docs at /docs."""
    return RedirectResponse("/docs")


@app.post("/entities", response_model=EntityResponse)
def extract_entities(
    entity_request: EntityRequest,
    extractor: SpacyExtractor = Depends(get_extractor),
) -> EntityResponse:
    """Extract Named Entities from a batch of Records.

    The spaCy extractor is injected via FastAPI's dependency system so the
    (expensive) model is loaded once and shared across requests.
    """
    return extractor.extract_entities(entity_request.texts)


# Legacy endpoint retained for zep v0.8.1 and earlier clients.
# NOTE(review): this registers the same path ("/embeddings/message") as
# embed_message_collection below; FastAPI matches routes in registration
# order, so confirm which handler is intended to serve requests.
@app.post(
    "/embeddings/message",
    description="Retained for legacy v0.8.1 and prior support. Will deprecate soon.",
    response_class=ORJSONResponse,
)
def embed_message_collection_legacy(
    collection: Collection, embedder: Embedder = Depends(get_embedder)
) -> ORJSONResponse:
    """Embed a Collection of Documents.

    Embeds with the configured messages model and serializes the result
    with orjson. Unlike the non-legacy handler, this one performs no
    enabled/disabled check on message embeddings.
    """
    return ORJSONResponse(
        embedder.embed(collection, settings.embeddings_messages_model)
    )


@app.post("/embeddings/message", response_class=ORJSONResponse)
def embed_message_collection(
    collection: Collection, embedder: Embedder = Depends(get_embedder)
) -> ORJSONResponse:
    """Embed a Collection of message Documents.

    Returns HTTP 400 when message embeddings are disabled in settings;
    otherwise embeds with the configured messages model and serializes
    the result with orjson.
    """
    if not settings.embeddings_messages_enabled:
        return ORJSONResponse(
            {"error": "Message embeddings are not enabled"}, status_code=400
        )

    return ORJSONResponse(
        embedder.embed(collection, settings.embeddings_messages_model)
    )
Expand All @@ -73,8 +66,13 @@ def embed_message_collection(
@app.post("/embeddings/document", response_class=ORJSONResponse)
def embed_document_collection(
    collection: Collection, embedder: Embedder = Depends(get_embedder)
) -> ORJSONResponse:
    """Embed a Collection of Documents.

    Returns HTTP 400 when document embeddings are disabled in settings;
    otherwise embeds with the configured documents model and serializes
    the result with orjson.
    """
    if not settings.embeddings_documents_enabled:
        # Fixed copy-paste bug: this endpoint must report *document*
        # embeddings, not message embeddings.
        return ORJSONResponse(
            {"error": "Document embeddings are not enabled"}, status_code=400
        )

    return ORJSONResponse(
        embedder.embed(collection, settings.embeddings_documents_model)
    )
3 changes: 3 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
class ComputeDevices(Enum):
    """Compute devices the embedder may be configured to run on.

    Values map directly to torch device strings (see app/embedder.py,
    which checks availability and falls back to CPU).
    """

    cpu = "cpu"
    cuda = "cuda"  # NVIDIA GPU
    mps = "mps"  # Apple Silicon (Metal Performance Shaders)


class Settings(BaseSettings):
Expand All @@ -20,6 +21,8 @@ class Settings(BaseSettings):
log_app_name: str = "zep-nlp"
server_port: int
embeddings_device: ComputeDevices
embeddings_messages_enabled: bool
embeddings_documents_enabled: bool
embeddings_messages_model: str
embeddings_documents_model: str
nlp_spacy_model: str
Expand Down
19 changes: 13 additions & 6 deletions app/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from sentence_transformers import SentenceTransformer # type: ignore

from app.config import log, settings
from app.config import ComputeDevices, log, settings
from app.embedding_models import Collection, Document


Expand All @@ -15,17 +15,24 @@ class Embedder:

def __init__(self):
device = "cpu"
if settings.embeddings_device == "cuda":
if settings.embeddings_device == ComputeDevices.cuda:
log.info("Configured for CUDA")
if torch.cuda.is_available():
device = "cuda"
else:
log.warning("Configured for CUDA but CUDA not available, using CPU")
elif settings.embeddings_device == ComputeDevices.mps:
log.info("Configured for MPS")
if torch.backends.mps.is_available():
device = "mps"
else:
log.warning("Configured for MPS but MPS not available, using CPU")

required_models = [
settings.embeddings_messages_model,
settings.embeddings_documents_model,
]
required_models: List[str] = []
if settings.embeddings_messages_enabled:
required_models.append(settings.embeddings_messages_model)
if settings.embeddings_documents_enabled:
required_models.append(settings.embeddings_documents_model)

models: dict[str, Any] = {}
for model_name in required_models:
Expand Down
4 changes: 3 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ embeddings:
# and torch with CUDA support
device: cpu
messages:
enabled: true
model: all-MiniLM-L6-v2
documents:
enabled: true
model: all-MiniLM-L6-v2
# This is the recommended, moderate-memory model for document embeddings
# model: multi-qa-mpnet-base-dot-v1
nlp:
  spacy_model: en_core_web_sm
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "nlp-server"
version = "0.4.0"
description = "Originated from spacy-cookiecutter by Microsoft"
authors = ["Daniel Chalef <131175+danielchalef@users.noreply.github.com>"]
license = "MIT"
Expand Down

0 comments on commit 19c6863

Please sign in to comment.