diff --git a/text2sql-backend/.gitignore b/text2sql-backend/.gitignore index d780e059..7895cd44 100644 --- a/text2sql-backend/.gitignore +++ b/text2sql-backend/.gitignore @@ -62,9 +62,11 @@ cover/ # Django stuff: *.log local_settings.py -db.sqlite3 -db.sqlite3-journal -test.sqlite3 + +# Test dbs +*.sqlite3 +*.sqlite3-journal +*.db # Flask stuff: instance/ diff --git a/text2sql-backend/context_builder.py b/text2sql-backend/context_builder.py index a22272a3..66492095 100644 --- a/text2sql-backend/context_builder.py +++ b/text2sql-backend/context_builder.py @@ -7,8 +7,8 @@ from llama_index.indices.struct_store import SQLContextContainerBuilder import db +from dataline.models.connection.schema import Connection from llm import ChatLLM -from models import Connection from sql_wrapper import CustomSQLDatabase from tokenizer import num_tokens_from_string @@ -49,7 +49,7 @@ def __init__( context_str: Optional[str] = None, model: Optional[str] = "gpt-4", embedding_model: Optional[str] = "text-embedding-ada-002", - temperature: Optional[int] = 0.0, + temperature: Optional[float] = 0.0, ): """Initialize params.""" self.connection: Connection = connection diff --git a/text2sql-backend/dataline/api/connection/router.py b/text2sql-backend/dataline/api/connection/router.py new file mode 100644 index 00000000..5734509f --- /dev/null +++ b/text2sql-backend/dataline/api/connection/router.py @@ -0,0 +1,269 @@ +import logging +import re +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Body, HTTPException +from pydantic import BaseModel, Field, field_validator +from sqlalchemy import create_engine +from sqlalchemy.exc import OperationalError + +import db +from dataline.config import config +from dataline.models.connection.schema import ( + ConnectionOut, + GetConnectionOut, + TableSchemasOut, +) +from dataline.repositories.base import NotFoundError +from dataline.utils import get_sqlite_dsn +from models import StatusType, SuccessResponse, UpdateConnectionRequest +from services import SchemaService + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["connections"]) + + +def create_db_connection(dsn: str, name: str, is_sample: bool = False) -> SuccessResponse[ConnectionOut]: + try: + engine = create_engine(dsn) + with engine.connect(): + pass + except OperationalError as exc: + # Try again replacing localhost with host.docker.internal to connect with DBs running in docker + if "localhost" in dsn: + dsn = dsn.replace("localhost", "host.docker.internal") + try: + engine = create_engine(dsn) + with engine.connect(): + pass + except OperationalError as e: + logger.error(e) + raise HTTPException(status_code=404, detail="Failed to connect to database") + else: + logger.error(exc) + raise HTTPException(status_code=404, detail="Failed to connect to database") + + # Check if connection with DSN already exists, then return connection_id + try: + existing_connection = db.get_connection_from_dsn(dsn) + if existing_connection: + return SuccessResponse(status=StatusType.ok, data=existing_connection) + except NotFoundError: + pass + + # Insert connection only if success + dialect = engine.url.get_dialect().name + database = engine.url.database + + if not database: + raise Exception("Invalid DSN. Database name is required.") + + with db.DatabaseManager() as conn: + connection_id = db.create_connection( + conn, + dsn, + database=database, + name=name, + dialect=dialect, + is_sample=is_sample, + ) + + SchemaService.create_or_update_tables(conn, connection_id) + conn.commit() # only commit if all step were successful + + return SuccessResponse( + status=StatusType.ok, + data=ConnectionOut( + id=connection_id, dsn=dsn, database=database, dialect=dialect, name=name, is_sample=is_sample + ), + ) + + +class ConnectRequest(BaseModel): + dsn: str = Field(min_length=3) + name: str + + @field_validator("dsn") + def validate_dsn_format(cls, value: str) -> str: + # Define a regular expression to match the DSN format + # Relaxed to allow for many kinds of DSNs + dsn_regex = r"^[\w\+]+:\/\/[\/\w-]+.*$" + + if not re.match(dsn_regex, value): + raise ValueError( + 'Invalid DSN format. The expected format is "driver://username:password@host:port/database".' + ) + + return value + + +@router.post("/create-sample-db") +async def create_sample_db() -> SuccessResponse[ConnectionOut]: + name = "DVD Rental (Sample)" + dsn = get_sqlite_dsn(config.sample_postgres_path) + return create_db_connection(dsn, name, is_sample=True) + + +@router.post("/connect", response_model_exclude_none=True) +async def connect_db(req: ConnectRequest) -> SuccessResponse[ConnectionOut]: + return create_db_connection(req.dsn, req.name) + + +@router.get("/connection/{connection_id}") +async def get_connection(connection_id: UUID) -> SuccessResponse[GetConnectionOut]: + with db.DatabaseManager() as conn: + return SuccessResponse( + status=StatusType.ok, + data=GetConnectionOut( + connection=db.get_connection(conn, connection_id), + ), + ) + + +class ConnectionsOut(BaseModel): + connections: list[ConnectionOut] + + +@router.get("/connections") +async def get_connections() -> SuccessResponse[ConnectionsOut]: + return SuccessResponse( + status=StatusType.ok, + data=ConnectionsOut( + connections=db.get_connections(), + ), + ) + + +@router.delete("/connection/{connection_id}") +async def delete_connection(connection_id: str) -> SuccessResponse[None]: + with db.DatabaseManager() as conn: + db.delete_connection(conn, connection_id) + return SuccessResponse(status=StatusType.ok) + + +@router.patch("/connection/{connection_id}") +async def update_connection(connection_id: UUID, req: UpdateConnectionRequest) -> SuccessResponse[GetConnectionOut]: + # Try to connect to provided dsn + try: + engine = create_engine(req.dsn) + with engine.connect(): + pass + except OperationalError as e: + logger.error(e) + raise HTTPException(status_code=400, detail="Failed to connect to database") + + # Update connection only if success + dialect = engine.url.get_dialect().name + database = str(engine.url.database) + + db.update_connection( + connection_id=connection_id, + dsn=req.dsn, + database=database, + name=req.name, + dialect=dialect, + ) + + return SuccessResponse( + status=StatusType.ok, + data=GetConnectionOut( + connection=ConnectionOut( + id=connection_id, + dsn=req.dsn, + database=database, + name=req.name, + dialect=dialect, + is_sample=False, # Don't care, just send False + ), + ), + ) + + +@router.get("/connection/{connection_id}/schemas") +async def get_table_schemas(connection_id: UUID) -> SuccessResponse[TableSchemasOut]: + # Check for connection existence + with db.DatabaseManager() as conn: + try: + db.get_connection(conn, connection_id) + except NotFoundError: + raise HTTPException(status_code=404, detail="Invalid connection_id") + + return SuccessResponse( + status=StatusType.ok, + data=TableSchemasOut( + tables=db.get_table_schemas_with_descriptions(connection_id), + ), + ) + + +@router.patch("/schemas/table/{table_id}") +async def update_table_schema_description( + table_id: str, description: Annotated[str, Body(embed=True)] +) -> dict[str, str]: + with db.DatabaseManager() as conn: + db.update_schema_table_description(conn, table_id=table_id, description=description) + conn.commit() + + return {"status": "ok"} + + +@router.patch("/schemas/field/{field_id}") +async def update_table_schema_field_description( + field_id: str, description: Annotated[str, Body(embed=True)] +) -> dict[str, str]: + with db.DatabaseManager() as conn: + db.update_schema_table_field_description(conn, field_id=field_id, description=description) + conn.commit() + + return {"status": "ok"} + + +# TODO: Convert to using services and session +# @router.post("/create-sample-db") +# async def create_sample_db( +# connection_service: ConnectionService = Depends(), +# session: AsyncSession = Depends(get_session), +# ) -> SuccessResponse[dict[str, str]]: +# name = "DVD Rental (Sample)" +# dsn = get_sqlite_dsn(config.sample_postgres_path) +# return await create_db_connection( +# session=session, connection_service=connection_service, dsn=dsn, name=name, is_sample=True +# ) + + +# @router.post("/connect", response_model_exclude_none=True) +# async def connect_db( +# req: ConnectionIn, connection_service: ConnectionService = Depends(), session: AsyncSession = Depends(get_session) +# ) -> SuccessResponse[dict[str, str]]: +# return await create_db_connection( +# session=session, connection_service=connection_service, dsn=req.dsn, name=req.name +# ) + + +# @router.get("/connection/{connection_id}") +# async def get_connection( +# connection_id: UUID, +# connection_service: ConnectionService = Depends(), +# session: AsyncSession = Depends(get_session), +# ) -> SuccessResponse[GetConnectionOut]: +# connection = await connection_service.get_connection(session, connection_id) +# return SuccessResponse( +# status=StatusType.ok, +# data=GetConnectionOut( +# connection=connection, +# ), +# ) + + +# @router.get("/connections") +# async def get_connections( +# connection_service: ConnectionService = Depends(), +# session: AsyncSession = Depends(get_session), +# ) -> SuccessResponse[GetConnectionListOut]: +# connections = await connection_service.get_connections(session) +# return SuccessResponse( +# status=StatusType.ok, +# data=GetConnectionListOut(connections=connections), +# ) diff --git a/text2sql-backend/dataline/models/connection/schema.py b/text2sql-backend/dataline/models/connection/schema.py new file mode 100644 index 00000000..bf6c1018 --- /dev/null +++ b/text2sql-backend/dataline/models/connection/schema.py @@ -0,0 +1,76 @@ +import re +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class ConnectionUpdateIn(BaseModel): + name: str | None = None + dsn: str | None = None + database: str | None = None + dialect: str | None = None + is_sample: bool | None = None + + +class Connection(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: UUID + name: str + dsn: str + database: str + dialect: str + is_sample: bool + + +class ConnectionOut(Connection): + model_config = ConfigDict(from_attributes=True) + + +class ConnectionIn(BaseModel): + dsn: str = Field(min_length=3) + name: str + + @field_validator("dsn") + def validate_dsn_format(cls, value: str) -> str: + # Define a regular expression to match the DSN format + dsn_regex = r"^[\w\+]+:\/\/[\w-]+:\w+@[\w.-]+[:\d]*\/\w+$" + + if not re.match(dsn_regex, value): + raise ValueError( + 'Invalid DSN format. The expected format is "driver://username:password@host:port/database".' + ) + + return value + + +class GetConnectionOut(BaseModel): + connection: ConnectionOut + + +class GetConnectionListOut(BaseModel): + connections: list[ConnectionOut] + + +class TableSchemaField(BaseModel): + id: str + schema_id: str + name: str + type: str + description: str + is_primary_key: Optional[bool] + is_foreign_key: Optional[bool] + linked_table: Optional[str] + + +class TableSchema(BaseModel): + id: str + connection_id: str + name: str + description: str + field_descriptions: list[TableSchemaField] + + +class TableSchemasOut(BaseModel): + tables: list[TableSchema] diff --git a/text2sql-backend/db.py b/text2sql-backend/db.py index 3ea759d1..10767e17 100644 --- a/text2sql-backend/db.py +++ b/text2sql-backend/db.py @@ -3,27 +3,29 @@ from sqlite3 import Cursor from sqlite3.dbapi2 import Connection as SQLiteConnection from typing import Any, List, Literal, Optional -from uuid import uuid4 +from uuid import UUID, uuid4 + +from sqlalchemy import event +from sqlalchemy.engine import Engine from dataline.config import config as dataline_config +from dataline.models.connection.schema import ( + ConnectionOut, + TableSchema, + TableSchemaField, +) from dataline.repositories.base import NotFoundError, NotUniqueError from models import ( - Connection, Conversation, ConversationWithMessagesWithResults, MessageWithResults, Result, - TableSchema, - TableSchemaField, UnsavedResult, ) -from sqlalchemy.engine import Engine -from sqlalchemy import event - @event.listens_for(Engine, "connect") -def set_sqlite_pragma(dbapi_connection, connection_record): +def set_sqlite_pragma(dbapi_connection: Any, connection_record: Any) -> None: # type: ignore[misc] cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() @@ -55,38 +57,41 @@ def create_connection( database: str, name: str = "", dialect: str = "", -) -> str: + is_sample: bool = False, +) -> UUID: # Check if connection_id or dsn already exist - connection_id = str(uuid4()) + # TODO: Doesn't this conflict with SQLAlchemy? + connection_id = uuid4() conn.execute( - "INSERT INTO connections (id, dsn, database, name, dialect) VALUES (?, ?, ?, ?, ?)", - (connection_id, dsn, database, name, dialect), + "INSERT INTO connections (id, dsn, database, name, dialect, is_sample) VALUES (?, ?, ?, ?, ?, ?)", + (str(connection_id), dsn, database, name, dialect, is_sample), ) return connection_id -def get_connection(conn: SQLiteConnection, connection_id: str) -> Connection: +def get_connection(conn: SQLiteConnection, connection_id: UUID) -> ConnectionOut: connection = conn.execute( - "SELECT id, name, dsn, database, dialect FROM connections WHERE id = ?", - (connection_id,), + "SELECT id, name, dsn, database, dialect, is_sample FROM connections WHERE id = ?", + (str(connection_id),), ).fetchone() if not connection: raise NotFoundError("Connection not found") - return Connection( + return ConnectionOut( id=connection[0], name=connection[1], dsn=connection[2], database=connection[3], dialect=connection[4], + is_sample=connection[5], ) -def update_connection(connection_id: str, name: str, dsn: str, database: str, dialect: str) -> bool: +def update_connection(connection_id: UUID, name: str, dsn: str, database: str, dialect: str) -> bool: conn.execute( "UPDATE connections SET name = ?, dsn = ?, database = ?, dialect = ? WHERE id = ?", - (name, dsn, database, dialect, connection_id), + (name, dsn, database, dialect, str(connection_id)), ) conn.commit() return True @@ -98,24 +103,27 @@ def delete_connection(conn: SQLiteConnection, connection_id: str) -> bool: return True -def get_connection_from_dsn(dsn: str) -> Connection: - data = conn.execute("SELECT id, name, dsn, database, dialect FROM connections WHERE dsn = ?", (dsn,)).fetchone() +def get_connection_from_dsn(dsn: str) -> ConnectionOut: + data = conn.execute( + "SELECT id, name, dsn, database, dialect, is_sample FROM connections WHERE dsn = ?", (dsn,) + ).fetchone() if not data: raise NotFoundError("Connection not found") - return Connection( - id=data[0], + return ConnectionOut( + id=UUID(data[0]), name=data[1], dsn=data[2], database=data[3], dialect=data[4], + is_sample=data[5], ) -def get_connections() -> List[Connection]: +def get_connections() -> List[ConnectionOut]: return [ - Connection(id=x[0], name=x[1], dsn=x[2], database=x[3], dialect=x[4]) - for x in conn.execute("SELECT id, name, dsn, database, dialect FROM connections").fetchall() + ConnectionOut(id=x[0], name=x[1], dsn=x[2], database=x[3], dialect=x[4], is_sample=x[5]) + for x in conn.execute("SELECT id, name, dsn, database, dialect, is_sample FROM connections").fetchall() ] @@ -193,7 +201,7 @@ def create_schema_field( return field_id -def get_table_schemas_with_descriptions(connection_id: str) -> List[TableSchema]: +def get_table_schemas_with_descriptions(connection_id: UUID) -> List[TableSchema]: # Select all table schemas for a connection and then join with schema_descriptions to get the field descriptions descriptions = conn.execute( """ @@ -212,7 +220,7 @@ def get_table_schemas_with_descriptions(connection_id: str) -> List[TableSchema] FROM schema_tables INNER JOIN schema_fields ON schema_tables.id = schema_fields.table_id WHERE schema_tables.connection_id = ?""", - (connection_id,), + (str(connection_id),), ).fetchall() # Join all the field descriptions for each table into a list of table schemas @@ -225,7 +233,7 @@ def get_table_schemas_with_descriptions(connection_id: str) -> List[TableSchema] # Return a list of TableSchema objects return [ TableSchema( - connection_id=connection_id, + connection_id=str(connection_id), id=table[0][0], name=table[0][2], description=table[0][3], @@ -389,6 +397,7 @@ def create_conversation(connection_id: str, name: str) -> int: if conversation_id is None: raise ValueError("Conversation could not be created") + conn.commit() return conversation_id @@ -514,7 +523,7 @@ def get_messages_with_results(conversation_id: str) -> list[MessageWithResults]: return messages -def get_message_history(conversation_id: str) -> list[dict[str, Any]]: +def get_message_history(conversation_id: str) -> list[dict[str, Any]]: # type: ignore[misc] """Returns the message history of a conversation in OpenAI API format""" messages = conn.execute( """SELECT content, role, created_at @@ -528,7 +537,7 @@ def get_message_history(conversation_id: str) -> list[dict[str, Any]]: return [{"role": message[1], "content": message[0]} for message in messages] -def get_message_history_with_selected_tables_with_sql( +def get_message_history_with_selected_tables_with_sql( # type: ignore[misc] conversation_id: str, ) -> list[dict[str, Any]]: """Returns the message history of a conversation with selected tables as a list""" @@ -553,7 +562,7 @@ def get_message_history_with_selected_tables_with_sql( ] -def get_message_history_with_sql(conversation_id: str) -> list[dict[str, Any]]: +def get_message_history_with_sql(conversation_id: str) -> list[dict[str, Any]]: # type: ignore[misc] """Returns the message history of a conversation with the SQL result encoded inside content in OpenAI API format""" messages_with_sql = conn.execute( """SELECT messages.content, messages.role, messages.created_at, results.content diff --git a/text2sql-backend/main.py b/text2sql-backend/main.py index 379b7d99..55c54ed0 100644 --- a/text2sql-backend/main.py +++ b/text2sql-backend/main.py @@ -1,37 +1,29 @@ import json import logging -import re from typing import Annotated +from uuid import UUID from fastapi import Body, Depends, HTTPException, Response -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel from pydantic.json import pydantic_encoder from pygments import lexers from pygments_pprint_sql import SqlFilter -from sqlalchemy import create_engine -from sqlalchemy.exc import OperationalError import db from app import App -from dataline.api.settings.router import router as settings_router -from dataline.config import config from dataline.repositories.base import AsyncSession, NotFoundError, get_session from dataline.services.settings import SettingsService -from dataline.utils import get_sqlite_dsn from models import ( - Connection, ConversationWithMessagesWithResults, DataResult, MessageWithResults, Result, StatusType, SuccessResponse, - TableSchema, UnsavedResult, - UpdateConnectionRequest, UpdateConversationRequest, ) -from services import QueryService, SchemaService, results_from_query_response +from services import QueryService, results_from_query_response from sql_wrapper import request_execute, request_limit logging.basicConfig(level=logging.DEBUG) @@ -43,221 +35,11 @@ app = App() -class ConnectRequest(BaseModel): - dsn: str = Field(min_length=3) - name: str - - @validator("dsn") - def validate_dsn_format(cls, value: str) -> str: - # Define a regular expression to match the DSN format - dsn_regex = r"^[\w\+]+:\/\/[\w-]+:\w+@[\w.-]+[:\d]*\/\w+$" - - if not re.match(dsn_regex, value): - raise ValueError( - 'Invalid DSN format. The expected format is "driver://username:password@host:port/database".' - ) - - return value - - -app.include_router(settings_router) - - @app.get("/healthcheck", response_model_exclude_none=True) async def healthcheck() -> SuccessResponse[None]: return SuccessResponse(status=StatusType.ok) -def create_db_connection(dsn: str, name: str) -> SuccessResponse[dict[str, str]]: - try: - engine = create_engine(dsn) - with engine.connect(): - pass - except OperationalError as exc: - # Try again replacing localhost with host.docker.internal to connect with DBs running in docker - if "localhost" in dsn: - dsn = dsn.replace("localhost", "host.docker.internal") - try: - engine = create_engine(dsn) - with engine.connect(): - pass - except OperationalError as e: - logger.error(e) - raise HTTPException(status_code=404, detail="Failed to connect to database") - else: - logger.error(exc) - raise HTTPException(status_code=404, detail="Failed to connect to database") - - # Check if connection with DSN already exists, then return connection_id - try: - existing_connection = db.get_connection_from_dsn(dsn) - if existing_connection: - return SuccessResponse( - status=StatusType.ok, - data={ - "connection_id": existing_connection.id, - }, - ) - - except NotFoundError: - pass - - # Insert connection only if success - dialect = engine.url.get_dialect().name - database = engine.url.database - - if not database: - raise Exception("Invalid DSN. Database name is required.") - - with db.DatabaseManager() as conn: - connection_id = db.create_connection( - conn, - dsn, - database=database, - name=name, - dialect=dialect, - ) - - SchemaService.create_or_update_tables(conn, connection_id) - conn.commit() # only commit if all step were successful - - return SuccessResponse( - status=StatusType.ok, - data={ - "connection_id": connection_id, - "database": database, - "dialect": dialect, - }, - ) - - -@app.post("/create-sample-db") -async def create_sample_db() -> SuccessResponse[dict[str, str]]: - name = "DVD Rental (Sample)" - dsn = get_sqlite_dsn(config.sample_postgres_path) - return create_db_connection(dsn, name) - - -@app.post("/connect", response_model_exclude_none=True) -async def connect_db(req: ConnectRequest) -> SuccessResponse[dict[str, str]]: - return create_db_connection(req.dsn, req.name) - - -class ConnectionsOut(BaseModel): - connections: list[Connection] - - -@app.get("/connections") -async def get_connections() -> SuccessResponse[ConnectionsOut]: - return SuccessResponse( - status=StatusType.ok, - data=ConnectionsOut( - connections=db.get_connections(), - ), - ) - - -class TableSchemasOut(BaseModel): - tables: list[TableSchema] - - -@app.get("/connection/{connection_id}/schemas") -async def get_table_schemas(connection_id: str) -> SuccessResponse[TableSchemasOut]: - # Check for connection existence - with db.DatabaseManager() as conn: - connection = db.get_connection(conn, connection_id) - if not connection: - raise HTTPException(status_code=404, detail="Invalid connection_id") - - return SuccessResponse( - status=StatusType.ok, - data=TableSchemasOut( - tables=db.get_table_schemas_with_descriptions(connection_id), - ), - ) - - -@app.patch("/schemas/table/{table_id}") -async def update_table_schema_description( - table_id: str, description: Annotated[str, Body(embed=True)] -) -> dict[str, str]: - with db.DatabaseManager() as conn: - db.update_schema_table_description(conn, table_id=table_id, description=description) - conn.commit() - - return {"status": "ok"} - - -@app.patch("/schemas/field/{field_id}") -async def update_table_schema_field_description( - field_id: str, description: Annotated[str, Body(embed=True)] -) -> dict[str, str]: - with db.DatabaseManager() as conn: - db.update_schema_table_field_description(conn, field_id=field_id, description=description) - conn.commit() - - return {"status": "ok"} - - -class ConnectionOut(BaseModel): - connection: Connection - - -@app.get("/connection/{connection_id}") -async def get_connection(connection_id: str) -> SuccessResponse[ConnectionOut]: - with db.DatabaseManager() as conn: - return SuccessResponse( - status=StatusType.ok, - data=ConnectionOut( - connection=db.get_connection(conn, connection_id), - ), - ) - - -@app.delete("/connection/{connection_id}") -async def delete_connection(connection_id: str) -> SuccessResponse[None]: - with db.DatabaseManager() as conn: - db.delete_connection(conn, connection_id) - return SuccessResponse(status=StatusType.ok) - - -@app.patch("/connection/{connection_id}") -async def update_connection(connection_id: str, req: UpdateConnectionRequest) -> SuccessResponse[ConnectionOut]: - # Try to connect to provided dsn - try: - engine = create_engine(req.dsn) - with engine.connect(): - pass - except OperationalError as e: - logger.error(e) - raise HTTPException(status_code=400, detail="Failed to connect to database") - - # Update connection only if success - dialect = engine.url.get_dialect().name - database = str(engine.url.database) - - db.update_connection( - connection_id=connection_id, - dsn=req.dsn, - database=database, - name=req.name, - dialect=dialect, - ) - - return SuccessResponse( - status=StatusType.ok, - data=ConnectionOut( - connection=Connection( - id=connection_id, - dsn=req.dsn, - database=database, - name=req.name, - dialect=dialect, - ), - ), - ) - - class ConversationsOut(BaseModel): conversations: list[ConversationWithMessagesWithResults] @@ -340,8 +122,9 @@ async def execute_sql( # Will raise error that's auto captured by middleware if not exists conversation = db.get_conversation(conversation_id) connection_id = conversation.connection_id - connection = db.get_connection(conn, connection_id) - if not connection: + try: + connection = db.get_connection(conn, UUID(connection_id)) + except NotFoundError: raise HTTPException(status_code=404, detail="Invalid connection_id") openai_key = await settings_service.get_openai_api_key(session) @@ -405,8 +188,9 @@ async def query( # Create query service and generate response connection_id = conversation.connection_id - connection = db.get_connection(conn, connection_id) - if not connection: + try: + connection = db.get_connection(conn, UUID(connection_id)) + except NotFoundError: raise HTTPException(status_code=404, detail="Invalid connection_id") openai_key = await settings_service.get_openai_api_key(session) diff --git a/text2sql-backend/models.py b/text2sql-backend/models.py index e4d91cd3..ce11aed3 100644 --- a/text2sql-backend/models.py +++ b/text2sql-backend/models.py @@ -72,28 +72,6 @@ class Conversation: created_at: datetime -class Connection(BaseModel): - id: str - name: str - database: str - dsn: str - dialect: str - - class Config: - table_name = "connections" - - -class TableSchemaField(BaseModel): - id: str - schema_id: str - name: str - type: str - description: str - is_primary_key: Optional[bool] - is_foreign_key: Optional[bool] - linked_table: Optional[str] - - class TableField(BaseModel): name: str type: str @@ -112,14 +90,6 @@ class TableFieldCreate(BaseModel): foreign_table: str = "" -class TableSchema(BaseModel): - id: str - connection_id: str - name: str - description: str - field_descriptions: list[TableSchemaField] - - @dataclass class ConversationWithMessagesWithResults(Conversation): messages: list[MessageWithResults] diff --git a/text2sql-backend/sql_wrapper.py b/text2sql-backend/sql_wrapper.py index 7bb4b430..1b0484b3 100644 --- a/text2sql-backend/sql_wrapper.py +++ b/text2sql-backend/sql_wrapper.py @@ -10,7 +10,7 @@ from sqlalchemy import text from sqlalchemy.exc import ProgrammingError -from models import TableSchema +from dataline.models.connection.schema import TableSchema logger = logging.getLogger(__name__) @@ -89,9 +89,7 @@ def get_schema_foreign_keys(self) -> dict[str, dict[str, str]]: return schema - def get_schema_with_user_descriptions( - self, descriptions: list[TableSchema] - ) -> dict[str, dict[str, str]]: + def get_schema_with_user_descriptions(self, descriptions: list[TableSchema]) -> dict[str, dict[str, str]]: # Create dict of descriptions descriptions_dict: dict[str, TableSchema] = {} for description in descriptions: @@ -119,11 +117,7 @@ def get_schema_with_user_descriptions( # Create a nicely formatted schema for each table for table, columns in schema.items(): formatted_schema = [] - table_description = ( - descriptions_dict[table].description - if table in descriptions_dict - else "" - ) + table_description = descriptions_dict[table].description if table in descriptions_dict else "" formatted_schema.append(f"Table: {table} : {table_description}\n") for column, column_info in columns.items(): formatted_schema.append(f"Column: {column}") @@ -174,9 +168,7 @@ def run_sql(self, command: str) -> tuple[dict, dict]: return result, {"result": result, "columns": list(cursor.keys())} return {}, {} - def validate_sql( - self, sql_query - ) -> tuple[Literal[True], None] | tuple[Literal[False], str]: + def validate_sql(self, sql_query) -> tuple[Literal[True], None] | tuple[Literal[False], str]: try: # Execute the EXPLAIN statement (without fetching results) conn = self._engine.raw_connection() diff --git a/text2sql-backend/tests/api/test_connection.py b/text2sql-backend/tests/api/test_connection.py new file mode 100644 index 00000000..2a032653 --- /dev/null +++ b/text2sql-backend/tests/api/test_connection.py @@ -0,0 +1,163 @@ +import logging + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient + +from dataline.models.connection.schema import Connection, TableSchema + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_create_sample_db_connection(client: TestClient) -> None: + response = client.post("/create-sample-db") + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["database"] + assert data["dialect"] == "sqlite" + assert data["is_sample"] is True + assert data["id"] + + +@pytest_asyncio.fixture +async def sample_db(client: TestClient) -> Connection: + response = client.post("/create-sample-db") + return Connection(**response.json()["data"]) + + +@pytest.mark.asyncio +async def test_get_connections(client: TestClient, sample_db: Connection) -> None: + response = client.get("/connections") + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["connections"] + assert len(data["connections"]) == 1 + + connections = data["connections"] + assert connections[0] == sample_db.model_dump(mode="json") + + +@pytest.mark.asyncio +async def test_get_connection(client: TestClient, sample_db: Connection) -> None: + response = client.get(f"/connection/{str(sample_db.id)}") + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["connection"] == sample_db.model_dump(mode="json") + + +@pytest.mark.asyncio +async def test_connect_db(client: TestClient) -> None: + connection_in = { + "dsn": "sqlite:///test.db", + "name": "Test", + } + response = client.post("/connect", json=connection_in) + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["id"] + assert data["dsn"] == connection_in["dsn"] + assert data["name"] == connection_in["name"] + assert data["dialect"] == "sqlite" + assert data["database"] + assert data["is_sample"] is False + + +@pytest.mark.asyncio +async def test_get_table_schemas(client: TestClient, sample_db: Connection) -> None: + response = client.get(f"/connection/{str(sample_db.id)}/schemas") + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["tables"] + assert len(data["tables"]) > 1 + + table_schema = data["tables"][0] + assert table_schema["id"] + assert table_schema["connection_id"] + assert table_schema["name"] is not None + assert table_schema["description"] is not None + + assert len(table_schema["field_descriptions"]) > 0 + + field = table_schema["field_descriptions"][0] + assert field["id"] + assert field["schema_id"] + assert field["name"] + assert field["type"] + assert field["description"] is not None + assert field["is_primary_key"] is not None + assert field["is_foreign_key"] is not None + assert field["linked_table"] is not None + + +@pytest_asyncio.fixture +async def example_table_schema(client: TestClient, sample_db: Connection) -> TableSchema: + response = client.get(f"/connection/{str(sample_db.id)}/schemas") + return TableSchema.model_validate(response.json()["data"]["tables"][0]) + + +@pytest.mark.asyncio +async def test_update_table_schema_description(client: TestClient, example_table_schema: TableSchema) -> None: + update_in = {"description": "New description"} + response = client.patch(f"/schemas/table/{example_table_schema.id}", json=update_in) + + assert response.status_code == 200 + + # Check if the description was updated + response = client.get(f"/connection/{example_table_schema.connection_id}/schemas") + data = response.json()["data"] + table_schema = TableSchema.model_validate(data["tables"][0]) + assert table_schema.description == update_in["description"] + + +@pytest.mark.asyncio +async def test_update_table_schema_field_description(client: TestClient, example_table_schema: TableSchema) -> None: + field = example_table_schema.field_descriptions[0] + update_in = {"description": "New description"} + response = client.patch(f"/schemas/field/{field.id}", json=update_in) + + assert response.status_code == 200 + + # Check if the description was updated + response = client.get(f"/connection/{example_table_schema.connection_id}/schemas") + data = response.json()["data"] + table_schema = TableSchema.model_validate(data["tables"][0]) + field = table_schema.field_descriptions[0] + assert field.description == update_in["description"] + + +@pytest.mark.asyncio +async def test_delete_connection(client: TestClient, sample_db: Connection) -> None: + response = client.delete(f"/connection/{str(sample_db.id)}") + + assert response.status_code == 200 + + # Check if the connection was deleted + response = client.get("/connections") + data = response.json()["data"] + assert len(data["connections"]) == 0 + + +@pytest.mark.asyncio +async def test_update_connection(client: TestClient, sample_db: Connection) -> None: + update_in = { + "dsn": "sqlite:///new.db", + "name": "New name", + } + response = client.patch(f"/connection/{str(sample_db.id)}", json=update_in) + + assert response.status_code == 200 + + data = response.json()["data"] + assert data["connection"]["dsn"] == update_in["dsn"] + assert data["connection"]["name"] == update_in["name"] diff --git a/text2sql-frontend/src/App.tsx b/text2sql-frontend/src/App.tsx index a476665f..128d4de9 100644 --- a/text2sql-frontend/src/App.tsx +++ b/text2sql-frontend/src/App.tsx @@ -7,7 +7,11 @@ import { SnackbarProvider } from "notistack"; import { HealthCheckProvider } from "./components/Providers/HealthcheckProvider"; export const App = () => ( - + diff --git a/text2sql-frontend/src/api.ts b/text2sql-frontend/src/api.ts index 9e31500d..4f4d7b84 100644 --- a/text2sql-frontend/src/api.ts +++ b/text2sql-frontend/src/api.ts @@ -56,12 +56,14 @@ const healthcheck = async (): Promise => { return response.data; }; -type Connection = { - connection_id: string; - database?: string; - dialect?: string; +export type ConnectionResult = { + id: string; + dsn: string; + database: string; + name: string; + dialect: string; }; -type ConnectResult = ApiResponse; +type ConnectResult = ApiResponse; const createConnection = async ( connectionString: string, name: string @@ -80,13 +82,6 @@ const createTestConnection = async (): Promise => { return response.data; }; -export type ConnectionResult = { - id: string; - dsn: string; - database: string; - name: string; - dialect: string; -}; export type ListConnectionsResult = ApiResponse<{ connections: ConnectionResult[]; }>;