Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Starter messages for new convos #99

Merged
merged 5 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions text2sql-backend/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ cover/
local_settings.py

# Test dbs
*.sqlite3
*.sqlite3-journal
*.db
db.sqlite3
db.sqlite3-journal
test.sqlite3

# Flask stuff:
instance/
Expand Down
6 changes: 5 additions & 1 deletion text2sql-backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@

from dataline.api.connection.router import router as connection_router
from dataline.api.settings.router import router as settings_router
from dataline.repositories.base import NotFoundError
from dataline.repositories.base import NotFoundError, NotUniqueError

logger = logging.getLogger(__name__)


def handle_exceptions(request: Request, e: Exception) -> JSONResponse:
if isinstance(e, NotFoundError):
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content={"message": e.message})
elif isinstance(e, NotUniqueError):
return JSONResponse(status_code=status.HTTP_409_CONFLICT, content={"message": e.message})

logger.exception(e)
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(e)})
Expand All @@ -35,4 +37,6 @@ def __init__(self) -> None:
self.include_router(connection_router)

# Handle 500s separately to play well with TestClient and allow re-raising in tests
self.add_exception_handler(NotFoundError, handle_exceptions)
self.add_exception_handler(NotUniqueError, handle_exceptions)
self.add_exception_handler(Exception, handle_exceptions)
15 changes: 9 additions & 6 deletions text2sql-backend/dataline/api/connection/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
GetConnectionOut,
TableSchemasOut,
)
from dataline.repositories.base import NotFoundError
from dataline.repositories.base import NotFoundError, NotUniqueError
from dataline.utils import get_sqlite_dsn
from models import StatusType, SuccessResponse, UpdateConnectionRequest
from services import SchemaService
Expand Down Expand Up @@ -49,7 +49,7 @@ def create_db_connection(dsn: str, name: str, is_sample: bool = False) -> Succes
try:
existing_connection = db.get_connection_from_dsn(dsn)
if existing_connection:
return SuccessResponse(status=StatusType.ok, data=existing_connection)
raise NotUniqueError("Connection already exists.")
except NotFoundError:
pass

Expand Down Expand Up @@ -92,17 +92,20 @@ def validate_dsn_format(cls, value: str) -> str:
dsn_regex = r"^[\w\+]+:\/\/[\/\w-]+.*$"

if not re.match(dsn_regex, value):
raise ValueError(
'Invalid DSN format. The expected format is "driver://username:password@host:port/database".'
)
raise ValueError("Invalid DSN format.")

# Simpler way to connect to postgres even though officially deprecated
# This mirrors psql which is a very common way to connect to postgres
if "postgres://" in value:
value = value.replace("postgres://", "postgresql://")

return value


@router.post("/create-sample-db")
async def create_sample_db() -> SuccessResponse[ConnectionOut]:
name = "DVD Rental (Sample)"
dsn = get_sqlite_dsn(config.sample_postgres_path)
dsn = get_sqlite_dsn(config.sample_dvdrental_path)
return create_db_connection(dsn, name, is_sample=True)


Expand Down
4 changes: 3 additions & 1 deletion text2sql-backend/dataline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ class Config(BaseSettings):
sqlite_path: str = str(Path(__file__).parent / "db.sqlite3")
sqlite_echo: bool = False

sample_postgres_path: str = str(Path(__file__).parent / "samples" / "postgres" / "dvd_rental.sqlite3")
sample_dvdrental_path: str = str(Path(__file__).parent / "samples" / "dvd_rental.sqlite3")
sample_netflix_path: str = str(Path(__file__).parent / "samples" / "netflix.sqlite3")
sample_titanic_path: str = str(Path(__file__).parent / "samples" / "titanic.sqlite3")


config = Config()
2 changes: 1 addition & 1 deletion text2sql-backend/dataline/models/conversation/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID

from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.orm import mapped_column, Mapped
from sqlalchemy.orm import Mapped, mapped_column

from dataline.models.base import DBModel
from dataline.models.connection import ConnectionModel
Expand Down
Binary file not shown.
Binary file not shown.
7 changes: 4 additions & 3 deletions text2sql-backend/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@

@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection: Any, connection_record: Any) -> None: # type: ignore[misc]
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
if type(dbapi_connection) is sqlite3.Connection: # play well with other DB backends
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()


# Old way of using database - this is a single connection, hard to manage transactions
Expand Down
6 changes: 6 additions & 0 deletions text2sql-backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataline.services.settings import SettingsService
from models import (
ConversationWithMessagesWithResults,
Conversation,
DataResult,
MessageWithResults,
Result,
Expand Down Expand Up @@ -88,6 +89,11 @@ async def delete_conversation(conversation_id: str) -> dict[str, str]:
return {"status": "ok"}


@app.get("/conversation/{conversation_id}")
async def get_conversation(conversation_id: str) -> SuccessResponse[Conversation]:
return SuccessResponse(status=StatusType.ok, data=db.get_conversation(conversation_id))


class ListMessageOut(BaseModel):
messages: list[MessageWithResults]

Expand Down
58 changes: 43 additions & 15 deletions text2sql-backend/tests/api/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import pathlib
from typing import AsyncGenerator

import pytest
import pytest_asyncio
Expand All @@ -22,11 +23,35 @@ async def test_create_sample_db_connection(client: TestClient) -> None:
assert data["is_sample"] is True
assert data["id"]

# TODO: Remove after sqlalchemy migration
# Manual rollback
client.delete(f"/connection/{data['id']}")


@pytest_asyncio.fixture
async def sample_db(client: TestClient) -> Connection:
async def sample_db(client: TestClient) -> AsyncGenerator[Connection, None]:
response = client.post("/create-sample-db")
assert response.status_code == 200
connection = Connection(**response.json()["data"])

# TODO: Remove after sqlalchemy migration
# Manual rollback
yield connection
client.delete(f"/connection/{str(connection.id)}")


@pytest.mark.asyncio
async def test_create_sample_db_connection_twice_409(client: TestClient) -> None:
response = client.post("/create-sample-db")
assert response.status_code == 200
connection = Connection(**response.json()["data"])

response = client.post("/create-sample-db")
return Connection(**response.json()["data"])
assert response.status_code == 409

# TODO: Remove after sqlalchemy migration
# Manual rollback
client.delete(f"/connection/{str(connection.id)}")


@pytest.mark.asyncio
Expand Down Expand Up @@ -74,6 +99,10 @@ async def test_connect_db(client: TestClient) -> None:
# Delete database after tests
pathlib.Path("test.db").unlink(missing_ok=True)

# TODO: Remove after sqlalchemy migration
# Manual rollback
client.delete(f"/connection/{data['id']}")


@pytest.mark.asyncio
async def test_get_table_schemas(client: TestClient, sample_db: Connection) -> None:
Expand Down Expand Up @@ -140,19 +169,6 @@ async def test_update_table_schema_field_description(client: TestClient, example
assert field.description == update_in["description"]


@pytest.mark.asyncio
@pytest.mark.skip(reason="Do not want to deal with this now")
async def test_delete_connection(client: TestClient, sample_db: Connection) -> None:
response = client.delete(f"/connection/{str(sample_db.id)}")

assert response.status_code == 200

# Check if the connection was deleted
response = client.get("/connections")
data = response.json()["data"]
assert len(data["connections"]) == 0


@pytest.mark.asyncio
async def test_update_connection(client: TestClient, sample_db: Connection) -> None:
update_in = {
Expand All @@ -169,3 +185,15 @@ async def test_update_connection(client: TestClient, sample_db: Connection) -> N

# Delete database after tests
pathlib.Path("new.db").unlink(missing_ok=True)


@pytest.mark.asyncio
async def test_delete_connection(client: TestClient, sample_db: Connection) -> None:
response = client.delete(f"/connection/{str(sample_db.id)}")

assert response.status_code == 200

# Check if the connection was deleted
response = client.get("/connections")
data = response.json()["data"]
assert len(data["connections"]) == 0
6 changes: 2 additions & 4 deletions text2sql-backend/tests/api/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest_asyncio
from fastapi.testclient import TestClient

from dataline.repositories.base import NotFoundError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,9 +96,8 @@ async def test_get_info(client: TestClient, user_info: dict[str, str]) -> None:

@pytest.mark.asyncio
async def test_get_info_not_found(client: TestClient) -> None:
with pytest.raises(NotFoundError):
response = client.get("/settings/info")
assert response.status_code == 404
response = client.get("/settings/info")
assert response.status_code == 404


FileTuple = tuple[str, tuple[str, BytesIO, str]]
Expand Down
2 changes: 1 addition & 1 deletion text2sql-backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def engine() -> AsyncGenerator[AsyncEngine, None]:
pathlib.Path("test.sqlite3").unlink(missing_ok=True)


@pytest_asyncio.fixture
@pytest_asyncio.fixture(scope="function")
async def session(engine: AsyncEngine, monkeypatch: pytest.MonkeyPatch) -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(engine) as session, session.begin():
# prevent test from committing anything, only flush
Expand Down
1 change: 1 addition & 0 deletions text2sql-frontend/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export type ConnectionResult = {
database: string;
name: string;
dialect: string;
is_sample: boolean;
};
type ConnectResult = ApiResponse<ConnectionResult>;
const createConnection = async (
Expand Down
Loading
Loading