Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Backend: mock unit tests #908

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/workflows/backend_integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,47 @@ on:
jobs:
pytest:
permissions: write-all
environment: development
# environment: development
runs-on: ubuntu-latest

steps:
- name: Checkout repo
uses: actions/checkout@v3

- uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'

- name: Install poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
virtualenvs-path: .venv
installer-parallel: true

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --with dev --no-interaction --no-root

- name: Setup test DB container
run: make test-db

- name: Test with pytest
if: github.actor != 'dependabot[bot]'
run: |
make run-integration-tests
env:
PYTHONPATH: src

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
with:
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/backend_unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,41 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v3

- uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'

- name: Install poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
virtualenvs-path: .venv
installer-parallel: true

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --with dev --no-interaction --no-root

- name: Setup test DB container
run: make test-db

- name: Test with pytest
if: github.actor != 'dependabot[bot]'
run: |
make run-unit-tests-debug
env:
PYTHONPATH: src

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4.0.1
with:
Expand Down
24 changes: 12 additions & 12 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator

import cohere

Expand Down Expand Up @@ -43,33 +43,33 @@ def __init__(self, **kwargs: Any):
base_url=self.chat_endpoint_url, api_key=self.api_key
)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Azure"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
def list_models(cls) -> List[str]:
def list_models(cls) -> list[str]:
if not cls.is_available():
return []

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
AzureDeployment.default_api_key is not None
and AzureDeployment.default_chat_endpoint_url is not None
)

async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
)
Expand All @@ -86,6 +86,6 @@ async def invoke_chat_stream(
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs
) -> Any:
return None
26 changes: 13 additions & 13 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List
from typing import Any

from backend.config.settings import Settings
from backend.schemas.cohere_chat import CohereChatRequest
Expand All @@ -25,32 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any):
def id(cls) -> str:
return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower()

@classmethod
@staticmethod
@abstractmethod
def name(cls) -> str: ...
def name() -> str: ...

@classmethod
@staticmethod
@abstractmethod
def env_vars(cls) -> List[str]: ...
def env_vars() -> list[str]: ...

@classmethod
@staticmethod
@abstractmethod
def rerank_enabled(cls) -> bool: ...
def rerank_enabled() -> bool: ...

@classmethod
@abstractmethod
def list_models(cls) -> List[str]: ...
def list_models(cls) -> list[str]: ...

@classmethod
@staticmethod
@abstractmethod
def is_available(cls) -> bool: ...
def is_available() -> bool: ...

@classmethod
def is_community(cls) -> bool:
return False

@classmethod
def config(cls) -> Dict[str, Any]:
def config(cls) -> dict[str, Any]:
config = Settings().get(f"deployments.{cls.id()}")
config_dict = {} if not config else dict(config)
for key, value in config_dict.items():
Expand Down Expand Up @@ -79,9 +79,9 @@ async def invoke_chat(
@abstractmethod
async def invoke_chat_stream(
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
) -> AsyncGenerator[Any, Any]: ...
) -> Any: ...

@abstractmethod
async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
) -> Any: ...
24 changes: 12 additions & 12 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator

import cohere

Expand Down Expand Up @@ -42,40 +42,40 @@ def __init__(self, **kwargs: Any):
),
)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Bedrock"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [
BEDROCK_ACCESS_KEY_ENV_VAR,
BEDROCK_SECRET_KEY_ENV_VAR,
BEDROCK_SESSION_TOKEN_ENV_VAR,
BEDROCK_REGION_NAME_ENV_VAR,
]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
def list_models(cls) -> List[str]:
def list_models(cls) -> list[str]:
if not cls.is_available():
return []

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
BedrockDeployment.access_key is not None
and BedrockDeployment.secret_access_key is not None
and BedrockDeployment.session_token is not None
and BedrockDeployment.region_name is not None
)

async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
# bedrock accepts a subset of the chat request fields
bedrock_chat_req = chat_request.model_dump(
exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True
Expand All @@ -101,6 +101,6 @@ async def invoke_chat_stream(
yield to_dict(event)

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
) -> Any:
return None
24 changes: 12 additions & 12 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any

import cohere
import requests
Expand Down Expand Up @@ -29,20 +29,20 @@ def __init__(self, **kwargs: Any):
)
self.client = cohere.Client(api_key, client_name=self.client_name)

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "Cohere Platform"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [COHERE_API_KEY_ENV_VAR]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return True

@classmethod
def list_models(cls) -> List[str]:
def list_models(cls) -> list[str]:
logger = LoggerFactory().get_logger()
if not CohereDeployment.is_available():
return []
Expand All @@ -64,12 +64,12 @@ def list_models(cls) -> List[str]:
models = response.json()["models"]
return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]]

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return CohereDeployment.api_key is not None

async def invoke_chat(
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
self, chat_request: CohereChatRequest, **kwargs: Any
) -> Any:
response = self.client.chat(
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
Expand Down Expand Up @@ -99,7 +99,7 @@ async def invoke_chat_stream(
yield event_dict

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
) -> Any:
response = self.client.rerank(
query=query, documents=documents, model=DEFAULT_RERANK_MODEL
Expand Down
22 changes: 11 additions & 11 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import json
from typing import Any, AsyncGenerator, Dict, List
from typing import Any, AsyncGenerator

import boto3

Expand Down Expand Up @@ -65,12 +65,12 @@ def __init__(self, **kwargs: Any):
"ContentType": "application/json",
}

@classmethod
def name(cls) -> str:
@staticmethod
def name() -> str:
return "SageMaker"

@classmethod
def env_vars(cls) -> List[str]:
@staticmethod
def env_vars() -> list[str]:
return [
SAGE_MAKER_ACCESS_KEY_ENV_VAR,
SAGE_MAKER_SECRET_KEY_ENV_VAR,
Expand All @@ -79,19 +79,19 @@ def env_vars(cls) -> List[str]:
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR,
]

@classmethod
def rerank_enabled(cls) -> bool:
@staticmethod
def rerank_enabled() -> bool:
return False

@classmethod
def list_models(cls) -> List[str]:
def list_models(cls) -> list[str]:
if not SageMakerDeployment.is_available():
return []

return cls.DEFAULT_MODELS

@classmethod
def is_available(cls) -> bool:
@staticmethod
def is_available() -> bool:
return (
SageMakerDeployment.region_name is not None
and SageMakerDeployment.aws_access_key_id is not None
Expand Down Expand Up @@ -121,7 +121,7 @@ async def invoke_chat_stream(
yield stream_event

async def invoke_rerank(
self, query: str, documents: List[Dict[str, Any]], ctx: Context
self, query: str, documents: list[str], ctx: Context, **kwargs
) -> Any:
return None

Expand Down
Loading
Loading