diff --git a/app/api.py b/app/api.py index ec93f81..58e7538 100644 --- a/app/api.py +++ b/app/api.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from fastapi import Depends, FastAPI, status from fastapi.responses import ORJSONResponse from starlette.responses import PlainTextResponse, RedirectResponse @@ -16,24 +18,24 @@ @app.on_event("startup") -def startup_event(): +def startup_event() -> None: get_embedder() get_extractor() @app.get("/healthz", response_model=str, status_code=status.HTTP_200_OK) -def health(): +def health() -> PlainTextResponse: return PlainTextResponse(".") @app.get("/config") -def config(): +def config() -> Dict[str, Any]: """Get the current configuration.""" return settings.dict() @app.get("/", include_in_schema=False) -def docs_redirect(): +def docs_redirect() -> RedirectResponse: return RedirectResponse("/docs") @@ -41,30 +43,21 @@ def docs_redirect(): def extract_entities( entity_request: EntityRequest, extractor: SpacyExtractor = Depends(get_extractor), -): +) -> EntityResponse: """Extract Named Entities from a batch of Records.""" return extractor.extract_entities(entity_request.texts) -@app.post( - "/embeddings/message", - description="Retained for legacy v0.8.1 and prior support. 
Will deprecate soon.", -    response_class=ORJSONResponse, -) -def embed_message_collection_legacy( -    collection: Collection, embedder: Embedder = Depends(get_embedder) -): -    """Embed a Collection of Documents.""" -    return ORJSONResponse( -        embedder.embed(collection, settings.embeddings_messages_model) -    ) - - @app.post("/embeddings/message", response_class=ORJSONResponse) def embed_message_collection( collection: Collection, embedder: Embedder = Depends(get_embedder) -): +) -> ORJSONResponse: """Embed a Collection of Documents.""" + if not settings.embeddings_messages_enabled: + return ORJSONResponse( + {"error": "Message embeddings are not enabled"}, status_code=400 + ) + return ORJSONResponse( embedder.embed(collection, settings.embeddings_messages_model) ) @@ -73,8 +66,13 @@ def embed_message_collection( @app.post("/embeddings/document", response_class=ORJSONResponse) def embed_document_collection( collection: Collection, embedder: Embedder = Depends(get_embedder) -): +) -> ORJSONResponse: """Embed a Collection of Documents.""" + if not settings.embeddings_documents_enabled: + return ORJSONResponse( + {"error": "Document embeddings are not enabled"}, status_code=400 + ) + return ORJSONResponse( embedder.embed(collection, settings.embeddings_documents_model) ) diff --git a/app/config.py b/app/config.py index 1a8abad..579fa2f 100644 --- a/app/config.py +++ b/app/config.py @@ -12,6 +12,7 @@ class ComputeDevices(Enum): cpu = "cpu" cuda = "cuda" + mps = "mps" class Settings(BaseSettings): @@ -20,6 +21,8 @@ class Settings(BaseSettings): log_app_name: str = "zep-nlp" server_port: int embeddings_device: ComputeDevices + embeddings_messages_enabled: bool + embeddings_documents_enabled: bool embeddings_messages_model: str embeddings_documents_model: str nlp_spacy_model: str diff --git a/app/embedder.py b/app/embedder.py index 305c4cd..c07257d 100644 --- a/app/embedder.py +++ b/app/embedder.py @@ -6,7 +6,7 @@ import torch from sentence_transformers import SentenceTransformer 
# type: ignore -from app.config import log, settings +from app.config import ComputeDevices, log, settings from app.embedding_models import Collection, Document @@ -15,17 +15,24 @@ class Embedder: def __init__(self): device = "cpu" - if settings.embeddings_device == "cuda": + if settings.embeddings_device == ComputeDevices.cuda: log.info("Configured for CUDA") if torch.cuda.is_available(): device = "cuda" else: log.warning("Configured for CUDA but CUDA not available, using CPU") + elif settings.embeddings_device == ComputeDevices.mps: + log.info("Configured for MPS") + if torch.backends.mps.is_available(): + device = "mps" + else: + log.warning("Configured for MPS but MPS not available, using CPU") - required_models = [ - settings.embeddings_messages_model, - settings.embeddings_documents_model, - ] + required_models: list[str] = [] + if settings.embeddings_messages_enabled: + required_models.append(settings.embeddings_messages_model) + if settings.embeddings_documents_enabled: + required_models.append(settings.embeddings_documents_model) models: dict[str, Any] = {} for model_name in required_models: diff --git a/config.yaml b/config.yaml index 5f07590..1468185 100644 --- a/config.yaml +++ b/config.yaml @@ -6,10 +6,12 @@ embeddings: # and torch with CUDA support device: cpu messages: + enabled: true model: all-MiniLM-L6-v2 documents: + enabled: true model: all-MiniLM-L6-v2 # This is the recommended, moderate-memory model for document embeddings # model: multi-qa-mpnet-base-dot-v1 nlp: - spacy_model: en_core_web_sm \ No newline at end of file + spacy_model: en_core_web_sm diff --git a/pyproject.toml b/pyproject.toml index df9945f..254efa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nlp-server" -version = "0.3.0" +version = "0.4.0" description = "Originated from spacy-cookiecutter by Microsoft" authors = ["Daniel Chalef <131175+danielchalef@users.noreply.github.com>"] license = "MIT"