Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "wherobots-python-dbapi"
version = "0.19.0"
version = "0.20.0"
description = "Python DB-API driver for Wherobots DB"
authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }]
requires-python = "~=3.8"
Expand Down Expand Up @@ -31,6 +31,7 @@ dev = [
"pre-commit",
"pytest>=8.0.2",
"pyarrow>=14.0.2",
"pyarrow-stubs>=17.0",
"pandas",
"rich>=13.7.1",
"mypy>=1.14.1",
Expand Down
46 changes: 45 additions & 1 deletion tests/smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
from rich.table import Table

from wherobots.db import connect, connect_direct
from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE
from wherobots.db.constants import (
DEFAULT_ENDPOINT,
DEFAULT_SESSION_TYPE,
DEFAULT_STORAGE_FORMAT,
)
from wherobots.db.connection import Connection
from wherobots.db.region import Region
from wherobots.db.session_type import SessionType
from wherobots.db.result_storage import StorageFormat, Store

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -48,7 +53,35 @@
parser.add_argument(
"--wide", help="Enable wide output", action="store_const", const=80, default=30
)
parser.add_argument(
"-s",
"--store",
help="Store results in temporary storage",
action="store_true",
)
parser.add_argument("sql", nargs="+", help="SQL query to execute")

args, unknown = parser.parse_known_args()
if args.store:
parser.add_argument(
"-sf",
"--storage-format",
help="Storage format for the results",
default=DEFAULT_STORAGE_FORMAT,
choices=[sf.value for sf in StorageFormat],
)
parser.add_argument(
"--single",
help="Generate only a single part file",
action="store_true",
)
parser.add_argument(
"-p",
"--presigned-url",
help="Generate a presigned URL for the results (only when --single is set)",
action="store_true",
)

args = parser.parse_args()

logging.basicConfig(
Expand All @@ -73,6 +106,16 @@
token = f.read().strip()
headers = {"Authorization": f"Bearer {token}"}

store = None
if args.store:
store = Store(
format=StorageFormat(args.storage_format)
if args.storage_format
else DEFAULT_STORAGE_FORMAT,
single=args.single,
generate_presigned_url=args.presigned_url,
)

if args.ws_url:
conn_func = functools.partial(connect_direct, uri=args.ws_url, headers=headers)
else:
Expand All @@ -86,6 +129,7 @@
region=Region(args.region) if args.region else Region.AWS_US_WEST_2,
version=args.version,
session_type=SessionType(args.session_type),
store=store,
)

def render(results: pandas.DataFrame) -> None:
Expand Down
40 changes: 38 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 16 additions & 3 deletions wherobots/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from wherobots.db.cursor import Cursor
from wherobots.db.errors import NotSupportedError, OperationalError
from wherobots.db.result_storage import Store


@dataclass
Expand Down Expand Up @@ -56,23 +57,25 @@ def __init__(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
):
self.__ws = ws
self.__read_timeout = read_timeout
self.__results_format = results_format
self.__data_compression = data_compression
self.__geometry_representation = geometry_representation
self.__store = store

self.__queries: dict[str, Query] = {}
self.__thread = threading.Thread(
target=self.__main_loop, daemon=True, name="wherobots-connection"
)
self.__thread.start()

def __enter__(self):
def __enter__(self) -> "Connection":
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()

def close(self) -> None:
Expand Down Expand Up @@ -134,6 +137,9 @@ def __listen(self) -> None:
# On a state_updated event telling us the query succeeded,
# ask for results.
if kind == EventKind.STATE_UPDATED:
logging.info(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to keep that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes and no. Since result_uri and size are currently part of the state updated event (and not the retrieve result), we might want to have a way to surface that info programmatically. Right now it's an easy way to just log dump it, but honestly it's a poor way to surface it.

We should talk about really making it part of a request result call, even if it costs an extra round trip, or about having a nicer way to handle it from the client side.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. My recommendation is to extend the interface of Connection beyond what the DB-API interface calls for, as that part is specifically for "DB driver"-like work where you get query results directly through the Cursor from the cursor() method.

Maybe we provide another, different type of handler:

def cursor(self) -> Cursor:
    return Cursor(self.__execute_sql, self.__cancel_query)

def stored(self, format: ResultFormat, single: bool, presigned: bool) -> StoredResult:
    return StoredResult(self.__execute_sql, self.__cancel_query, format, single, presigned)

Which exposes:

class StoredResult:
    def __init__(self, exec_fn, cancel_fn, format: ResultFormat, single: bool, presigned: bool):
        # ...

    def execute(self, operation: str, parameters: dict[str, Any] = None):
        # ... call exec_fn with the correct store{} parameters

    def get_result_url(self) -> str:
        # ...

"Query %s succeeded; full message is %s", execution_id, message
)
self.__request_results(execution_id)
return

Expand Down Expand Up @@ -169,7 +175,7 @@ def _handle_results(self, execution_id: str, results: Dict[str, Any]) -> Any:
result_compression = results.get("compression")
logging.info(
"Received %d bytes of %s-compressed %s results from %s.",
len(result_bytes),
len(result_bytes) if result_bytes else 0,
result_compression,
result_format,
execution_id,
Expand Down Expand Up @@ -209,6 +215,13 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str:
"statement": sql,
}

if self.__store:
request["store"] = {
"format": self.__store.format.value if self.__store.format else None,
"single": str(self.__store.single),
"generate_presigned_url": str(self.__store.generate_presigned_url),
}

self.__queries[execution_id] = Query(
sql=sql,
execution_id=execution_id,
Expand Down
2 changes: 2 additions & 0 deletions wherobots/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .region import Region
from .runtime import Runtime
from .session_type import SessionType
from .result_storage import StorageFormat


DEFAULT_ENDPOINT: str = "api.cloud.wherobots.com" # "api.cloud.wherobots.com"
Expand All @@ -14,6 +15,7 @@
DEFAULT_REGION: Region = Region.AWS_US_WEST_2
DEFAULT_VERSION: str = "latest"
DEFAULT_SESSION_TYPE: SessionType = SessionType.MULTI
DEFAULT_STORAGE_FORMAT: StorageFormat = StorageFormat.PARQUET
DEFAULT_READ_TIMEOUT_SECONDS: float = 0.25
DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS: float = 900

Expand Down
5 changes: 5 additions & 0 deletions wherobots/db/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from .region import Region
from .runtime import Runtime
from .result_storage import Store

apilevel = "2.0"
threadsafety = 1
Expand Down Expand Up @@ -72,6 +73,7 @@ def connect(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
) -> Connection:
if not token and not api_key:
raise ValueError("At least one of `token` or `api_key` is required")
Expand Down Expand Up @@ -157,6 +159,7 @@ def get_session_uri() -> str:
results_format=results_format,
data_compression=data_compression,
geometry_representation=geometry_representation,
store=store,
)


Expand All @@ -177,6 +180,7 @@ def connect_direct(
results_format: Union[ResultsFormat, None] = None,
data_compression: Union[DataCompression, None] = None,
geometry_representation: Union[GeometryRepresentation, None] = None,
store: Union[Store, None] = None,
) -> Connection:
uri_with_protocol = f"{uri}/{protocol}"

Expand All @@ -199,4 +203,5 @@ def connect_direct(
results_format=results_format,
data_compression=data_compression,
geometry_representation=geometry_representation,
store=store,
)
22 changes: 22 additions & 0 deletions wherobots/db/result_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass
from enum import auto
from strenum import LowercaseStrEnum
from typing import Union


class StorageFormat(LowercaseStrEnum):
    """File formats in which stored query results can be written."""

    # Values are generated by auto(); under LowercaseStrEnum this is expected
    # to be the lowercased member name (e.g. "parquet") -- confirm against the
    # strenum package documentation.
    PARQUET = auto()
    CSV = auto()
    GEOJSON = auto()


@dataclass
class Store:
    """Options describing how query results should be stored.

    Attributes:
        format: Storage format for the results; ``None`` leaves the format
            unspecified.
        single: Whether to generate only a single part file.
        generate_presigned_url: Whether to generate a presigned URL for the
            results. Only valid together with ``single=True``.

    Raises:
        ValueError: If ``generate_presigned_url`` is set while ``single`` is
            not.
    """

    format: Union[StorageFormat, None] = None
    single: bool = False
    generate_presigned_url: bool = False

    def __post_init__(self) -> None:
        """Reject option combinations that cannot be honored."""
        # Raise explicitly instead of using `assert`: assertions are removed
        # when Python runs with -O, which would let the invalid combination
        # through unchecked.
        if self.generate_presigned_url and not self.single:
            raise ValueError(
                "Presigned URL can only be generated when single part file is requested."
            )