Skip to content

Commit

Permalink
Merge pull request #113 from RamiAwar/S6.3-openai-key-validation
Browse files Browse the repository at this point in the history
S6.3 openai key validation
  • Loading branch information
RamiAwar authored Apr 6, 2024
2 parents 9f2936e + 99330a7 commit d68ff3a
Show file tree
Hide file tree
Showing 16 changed files with 159 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Add preferred openai model field
Revision ID: ffda068eb4b2
Revises: 32035ba6ade3
Create Date: 2024-04-01 22:21:50.205103
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "ffda068eb4b2"
down_revision: Union[str, None] = "32035ba6ade3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add the nullable ``preferred_openai_model`` string column to the ``user`` table."""
    # Build the column definition up front so the batch block only applies it.
    preferred_model_column = sa.Column("preferred_openai_model", sa.String(), nullable=True)

    with op.batch_alter_table("user", schema=None) as user_table:
        user_table.add_column(preferred_model_column)


def downgrade() -> None:
    """Drop the ``preferred_openai_model`` column from the ``user`` table."""
    column_name = "preferred_openai_model"

    with op.batch_alter_table("user", schema=None) as user_table:
        user_table.drop_column(column_name)
2 changes: 1 addition & 1 deletion text2sql-backend/context_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(
connection: Connection,
sql_database: CustomSQLDatabase,
openai_api_key: str,
model: str,
context_dict: Optional[dict[str, str]] = None,
context_str: Optional[str] = None,
model: Optional[str] = "gpt-4",
embedding_model: Optional[str] = "text-embedding-ada-002",
temperature: Optional[float] = 0.0,
):
Expand Down
2 changes: 2 additions & 0 deletions text2sql-backend/dataline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ class Config(BaseSettings):
sample_netflix_path: str = str(Path(__file__).parent / "samples" / "netflix.sqlite3")
sample_titanic_path: str = str(Path(__file__).parent / "samples" / "titanic.sqlite3")

default_model: str = "gpt-4"


config = Config()
1 change: 1 addition & 0 deletions text2sql-backend/dataline/models/user/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ class UserModel(DBModel, UUIDMixin):
__tablename__ = "user"
name: Mapped[str | None] = mapped_column("name", String(100), nullable=True)
openai_api_key: Mapped[str | None] = mapped_column("openai_api_key", String, nullable=True)
preferred_openai_model: Mapped[str | None] = mapped_column("preferred_openai_model", String, nullable=True)
22 changes: 20 additions & 2 deletions text2sql-backend/dataline/models/user/schema.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field
import openai
from pydantic import BaseModel, ConfigDict, Field, field_validator

from dataline.config import config


class UserUpdateIn(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=250)
openai_api_key: Optional[str] = Field(None, min_length=1)
openai_api_key: Optional[str] = Field(None, min_length=4, pattern=r"^sk-(\w|\d)+$")
preferred_openai_model: Optional[str] = None

@field_validator("openai_api_key")
@classmethod
def check_openai_key(cls, openai_key: Optional[str]) -> Optional[str]:
    """Validate an OpenAI API key by listing the models it can access.

    Args:
        openai_key: The submitted key, or ``None`` when the client explicitly
            sends a null value for this optional field.

    Returns:
        The key unchanged when it is usable, or ``None`` when no key was given.

    Raises:
        ValueError: If OpenAI rejects the key, or if the key cannot access
            ``config.default_model`` or ``"gpt-3.5-turbo"``.
    """
    # The field is Optional: an explicit null must not be turned into a client
    # built with api_key=None, so there is nothing to check in that case.
    if openai_key is None:
        return openai_key

    required_models = [config.default_model, "gpt-3.5-turbo"]
    client = openai.OpenAI(api_key=openai_key)
    try:
        # Keep the iteration inside the try: consuming the listing may itself
        # raise the authentication error.
        has_required_model = any(model.id in required_models for model in client.models.list())
    except openai.AuthenticationError as e:
        raise ValueError("Invalid OpenAI Key") from e

    if not has_required_model:
        raise ValueError(f"Must have access to at least one of {required_models}")
    return openai_key


# Response schema for user settings: every field is optional and defaults to None.
class UserOut(BaseModel):
    # from_attributes=True so instances can be validated from attribute-bearing
    # objects (the settings service calls UserOut.model_validate(user)).
    model_config = ConfigDict(from_attributes=True)

    # Display name, if one has been set.
    name: Optional[str] = None
    # Stored OpenAI API key, if one has been set.
    openai_api_key: Optional[str] = None
    # Id of the user's preferred OpenAI model, if one has been set.
    preferred_openai_model: Optional[str] = None


class AvatarOut(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions text2sql-backend/dataline/repositories/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class UserCreate(BaseModel):

name: Optional[str] = None
openai_api_key: Optional[str] = None
preferred_openai_model: Optional[str] = None


# Update payload for the user repository; inherits every field from UserCreate and adds none.
class UserUpdate(UserCreate): ...
Expand Down
30 changes: 30 additions & 0 deletions text2sql-backend/dataline/services/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
from uuid import uuid4

from fastapi import Depends, UploadFile
import openai

from dataline.config import config
from dataline.models.media.model import MediaModel
from dataline.models.user.schema import UserOut, UserUpdateIn
from dataline.repositories.base import AsyncSession, NotFoundError
from dataline.repositories.media import MediaCreate, MediaRepository
from dataline.repositories.user import UserCreate, UserRepository, UserUpdate


def model_exists(openai_api_key: str, model: str) -> bool:
    """Return whether ``model`` is one of the model ids listed for ``openai_api_key``.

    Args:
        openai_api_key: Key used to construct the OpenAI client.
        model: Model id to look for in the client's model listing.

    Returns:
        ``True`` if a listed model has an ``id`` equal to ``model``, else ``False``.
    """
    listed_models = openai.OpenAI(api_key=openai_api_key).models.list()
    # Use a distinct loop name so the ``model`` parameter is not shadowed, and
    # stop at the first match instead of materialising every id.
    return any(listed.id == model for listed in listed_models)


class SettingsService:
media_repo: MediaRepository
user_repo: UserRepository
Expand Down Expand Up @@ -62,10 +69,26 @@ async def update_user_info(self, session: AsyncSession, data: UserUpdateIn) -> U
if user_info is None:
# Create user with data
user_create = UserCreate.model_construct(**data.model_dump(exclude_none=True))
if user_create.openai_api_key and user_create.preferred_openai_model is None:
user_create.preferred_openai_model = (
config.default_model
if model_exists(user_create.openai_api_key, config.default_model)
else "gpt-3.5-turbo"
)
user = await self.user_repo.create(session, user_create)
else:
# Update user with data
user_update = UserUpdate.model_construct(**data.model_dump(exclude_none=True))
if user_update.openai_api_key:
key_to_check = user_update.openai_api_key
model_to_check = (
user_update.preferred_openai_model or user_info.preferred_openai_model or config.default_model
)
if not model_exists(key_to_check, model_to_check):
raise Exception(f"model {model_to_check} not accessible with current key")
elif user_update.preferred_openai_model and user_info.openai_api_key:
if not model_exists(user_info.openai_api_key, user_update.preferred_openai_model):
raise Exception(f"model {user_update.preferred_openai_model} not accessible with current key")
user = await self.user_repo.update_by_id(session, record_id=user_info.id, data=user_update)

return UserOut.model_validate(user)
Expand All @@ -83,3 +106,10 @@ async def get_openai_api_key(self, session: AsyncSession) -> str:
raise Exception("User or OpenAI key not setup. Please setup your application.")

return user_info.openai_api_key

async def get_preferred_model(self, session: AsyncSession) -> str:
    """Return the stored preferred OpenAI model, or ``config.default_model`` if unset.

    Raises:
        Exception: If no user record exists or the user has no OpenAI API key.
    """
    user = await self.user_repo.get_one_or_none(session)

    # A preferred model is only meaningful once a user with a key exists.
    has_key = user is not None and bool(user.openai_api_key)
    if not has_key:
        raise Exception("User or OpenAI key not setup. Please setup your application.")

    preferred = user.preferred_openai_model
    if preferred:
        return preferred
    return config.default_model
4 changes: 2 additions & 2 deletions text2sql-backend/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import AsyncIterator, Awaitable, Callable, Literal
from typing import AsyncIterator

import openai

Expand All @@ -10,7 +10,7 @@ class ChatLLM:
def __init__(
self,
openai_api_key: str,
model: Literal["gpt-4"] = "gpt-4",
model: str,
temperature: float = 0.0,
):
self.model = model
Expand Down
6 changes: 4 additions & 2 deletions text2sql-backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ async def execute_sql(
raise HTTPException(status_code=404, detail="Invalid connection_id")

openai_key = await settings_service.get_openai_api_key(session)
query_service = QueryService(connection, openai_api_key=openai_key)
preferred_model = await settings_service.get_preferred_model(session)
query_service = QueryService(connection, openai_api_key=openai_key, model_name=preferred_model)

# Execute query
data = query_service.run_sql(sql)
Expand Down Expand Up @@ -200,7 +201,8 @@ async def query(
raise HTTPException(status_code=404, detail="Invalid connection_id")

openai_key = await settings_service.get_openai_api_key(session)
query_service = QueryService(connection=connection, openai_api_key=openai_key, model_name="gpt-3.5-turbo")
preferred_model = await settings_service.get_preferred_model(session)
query_service = QueryService(connection=connection, openai_api_key=openai_key, model_name=preferred_model)
response = await query_service.query(query, conversation_id=conversation_id)
unsaved_results = results_from_query_response(response)

Expand Down
4 changes: 2 additions & 2 deletions text2sql-backend/query_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
import logging
from typing import Literal, Optional
from typing import Optional

import openai
from openai.types.chat import ChatCompletionChunk
Expand All @@ -17,8 +17,8 @@ def __init__(
self,
dsn: str,
openai_api_key: str,
model: str,
examples: Optional[dict] = None,
model: Literal["gpt-4"] = "gpt-4",
embedding_model: Optional[str] = "text-embedding-ada-002",
temperature: float = 0.0,
):
Expand Down
6 changes: 4 additions & 2 deletions text2sql-backend/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,17 @@ def __init__(
self,
connection: Connection,
openai_api_key: str,
model_name: str = "gpt-4",
model_name: str,
temperature: float = 0.0,
) -> None:
self.session = connection
self.engine = create_engine(connection.dsn)
self.insp = inspect(self.engine)
self.table_names = self.insp.get_table_names()
self.sql_db = CustomSQLDatabase(self.engine, include_tables=self.table_names)
self.context_builder = CustomSQLContextContainerBuilder(connection, self.sql_db, openai_api_key=openai_api_key)
self.context_builder = CustomSQLContextContainerBuilder(
connection, self.sql_db, openai_api_key=openai_api_key, model=model_name
)
self.query_manager = SQLQueryManager(
dsn=connection.dsn, openai_api_key=openai_api_key, model=model_name, temperature=temperature
)
Expand Down
29 changes: 23 additions & 6 deletions text2sql-backend/tests/api/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from openai.resources.models import Models as OpenAIModels
from unittest.mock import MagicMock, patch


logger = logging.getLogger(__name__)
Expand All @@ -21,6 +23,7 @@ async def test_update_user_info_name(client: TestClient) -> None:
"data": {
"name": "John",
"openai_api_key": None,
"preferred_openai_model": None,
},
}

Expand Down Expand Up @@ -55,24 +58,38 @@ async def test_update_user_info_invalid_openai_key(client: TestClient) -> None:


@pytest.mark.asyncio
async def test_update_user_info_valid_openai_key(client: TestClient) -> None:
@patch.object(OpenAIModels, "list")
async def test_update_user_info_valid_openai_key(mock_openai_model_list: MagicMock, client: TestClient) -> None:
mock_model = MagicMock()
mock_model.id = "gpt-4"
mock_openai_model_list.return_value = [mock_model]
openai_key = "sk-Mioanowida"
user_in = {"openai_api_key": openai_key}
response = client.patch("/settings/info", json=user_in)
assert response.status_code == 200
assert response.status_code == 200, response.json()
assert response.json()["data"]["openai_api_key"] == openai_key
mock_openai_model_list.assert_called()


@pytest.mark.asyncio
async def test_update_user_info_extra_fields_ignored(client: TestClient) -> None:
user_in = {"name": "John", "openai_api_key": "key", "extra": "extra"}
@patch.object(OpenAIModels, "list")
async def test_update_user_info_extra_fields_ignored(mock_openai_model_list: MagicMock, client: TestClient) -> None:
mock_model = MagicMock()
mock_model.id = "gpt-4"
mock_openai_model_list.return_value = [mock_model]
user_in = {"name": "John", "openai_api_key": "sk-1234", "extra": "extra"}
response = client.patch("/settings/info", json=user_in)
assert response.status_code == 200
assert "extra" not in response.json()["data"]
mock_openai_model_list.assert_called()


@pytest_asyncio.fixture
async def user_info(client: TestClient) -> dict[str, str]:
@patch.object(OpenAIModels, "list")
async def user_info(mock_openai_model_list: MagicMock, client: TestClient) -> dict[str, str]:
mock_model = MagicMock()
mock_model.id = "gpt-4"
mock_openai_model_list.return_value = [mock_model]
user_in = {
"name": "John",
"openai_api_key": "sk-asoiasdfl",
Expand All @@ -91,7 +108,7 @@ async def test_get_info(client: TestClient, user_info: dict[str, str]) -> None:

# Check that the response body contains the expected data
# Replace this with your actual assertions based on your application's logic
assert response.json()["data"] == user_info
assert response.json()["data"] == {**user_info, "preferred_openai_model": "gpt-4"}


@pytest.mark.asyncio
Expand Down
2 changes: 1 addition & 1 deletion text2sql-frontend/src/components/Home/Home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export const Home = () => {
</div>
) : (
<div>
<OpenAIKeyPopup></OpenAIKeyPopup>
<OpenAIKeyPopup />
</div>
);
};
13 changes: 13 additions & 0 deletions text2sql-frontend/src/components/Settings/OpenAIKeyPopup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ export function OpenAIKeyPopup() {
onKeyUp={handleKeyPress}
/>
</AlertBody>
<AlertDescription>
<p className="text-xs">
* Please update your API key with{" "}
<a
className="underline"
target="_blank"
href="https://help.openai.com/en/articles/8867743-assign-api-key-permissions"
>
full permissions{" "}
</a>
to use DataLine.
</p>
</AlertDescription>
<AlertActions>
<Button onClick={() => saveApiKey()}>Continue</Button>
</AlertActions>
Expand Down
11 changes: 11 additions & 0 deletions text2sql-frontend/src/components/Settings/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,17 @@ export default function Account() {
onChange={setApiKey}
/>
</div>
<p className="text-xs text-gray-400 pt-2">
* Please update your API key with{" "}
<a
className="underline"
target="_blank"
href="https://help.openai.com/en/articles/8867743-assign-api-key-permissions"
>
full permissions{" "}
</a>
to use DataLine.
</p>
</div>
</div>
<div className="mt-8 flex">
Expand Down
16 changes: 9 additions & 7 deletions text2sql-frontend/src/components/Settings/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { api } from "@/api";
import { isAxiosError } from "axios";
import { enqueueSnackbar } from "notistack";

export async function updateName(name: string | null) {
Expand All @@ -15,19 +16,20 @@ export async function updateName(name: string | null) {
}

export async function updateApiKey(apiKey: string | null): Promise<boolean> {
const invalidKeyMessage = "Invalid OpenAI API key.";
if (apiKey === null || apiKey === "" || !apiKey.startsWith("sk-")) {
// TODO: Show error banner: Invalid OpenAI API key
enqueueSnackbar({
variant: "error",
message: "Invalid OpenAI API key.",
});
enqueueSnackbar({ variant: "error", message: invalidKeyMessage });
return false;
}
try {
await api.updateUserInfo({ openai_api_key: apiKey });
return true;
} catch {
enqueueSnackbar({ variant: "error", message: "Error updating API key" });
} catch (exception) {
if (isAxiosError(exception) && exception.response?.status === 422) {
enqueueSnackbar({ variant: "error", message: invalidKeyMessage });
} else {
enqueueSnackbar({ variant: "error", message: "Error updating API key" });
}
return false;
}
}

0 comments on commit d68ff3a

Please sign in to comment.