diff --git a/pyproject.toml b/pyproject.toml index e94eeaa..15ac790 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wherobots-python-dbapi" -version = "0.19.0" +version = "0.20.0" description = "Python DB-API driver for Wherobots DB" authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }] requires-python = "~=3.8" @@ -31,6 +31,7 @@ dev = [ "pre-commit", "pytest>=8.0.2", "pyarrow>=14.0.2", + "pyarrow-stubs>=17.0", "pandas", "rich>=13.7.1", "mypy>=1.14.1", diff --git a/tests/smoke.py b/tests/smoke.py index 9487839..ac53695 100644 --- a/tests/smoke.py +++ b/tests/smoke.py @@ -11,10 +11,15 @@ from rich.table import Table from wherobots.db import connect, connect_direct -from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE +from wherobots.db.constants import ( + DEFAULT_ENDPOINT, + DEFAULT_SESSION_TYPE, + DEFAULT_STORAGE_FORMAT, +) from wherobots.db.connection import Connection from wherobots.db.region import Region from wherobots.db.session_type import SessionType +from wherobots.db.result_storage import StorageFormat, Store if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -48,7 +53,35 @@ parser.add_argument( "--wide", help="Enable wide output", action="store_const", const=80, default=30 ) + parser.add_argument( + "-s", + "--store", + help="Store results in temporary storage", + action="store_true", + ) parser.add_argument("sql", nargs="+", help="SQL query to execute") + + args, unknown = parser.parse_known_args() + if args.store: + parser.add_argument( + "-sf", + "--storage-format", + help="Storage format for the results", + default=DEFAULT_STORAGE_FORMAT, + choices=[sf.value for sf in StorageFormat], + ) + parser.add_argument( + "--single", + help="Generate only a single part file", + action="store_true", + ) + parser.add_argument( + "-p", + "--presigned-url", + help="Generate a presigned URL for the results (only when --single is set)", + action="store_true", + ) + args = 
parser.parse_args() logging.basicConfig( @@ -73,6 +106,16 @@ token = f.read().strip() headers = {"Authorization": f"Bearer {token}"} + store = None + if args.store: + store = Store( + format=StorageFormat(args.storage_format) + if args.storage_format + else DEFAULT_STORAGE_FORMAT, + single=args.single, + generate_presigned_url=args.presigned_url, + ) + if args.ws_url: conn_func = functools.partial(connect_direct, uri=args.ws_url, headers=headers) else: @@ -86,6 +129,7 @@ region=Region(args.region) if args.region else Region.AWS_US_WEST_2, version=args.version, session_type=SessionType(args.session_type), + store=store, ) def render(results: pandas.DataFrame) -> None: diff --git a/uv.lock b/uv.lock index 1bcd952..36f51c0 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.8, <4" resolution-markers = [ "python_full_version >= '3.12'", @@ -1027,6 +1027,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/15/6b30e77872012bbfe8265d42a01d5b3c17ef0ac0f2fae531ad91b6a6c02e/pyarrow-21.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdc4c17afda4dab2a9c0b79148a43a7f4e1094916b3e18d8975bfd6d6d52241f", size = 26227521, upload-time = "2025-07-18T00:57:29.119Z" }, ] +[[package]] +name = "pyarrow-stubs" +version = "17.19" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.9'", +] +dependencies = [ + { name = "pyarrow", version = "17.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7b/a7/fc696ec6905698853821ea20b9b46fc1b0a8ddee7a17cda2638c03ac9d8d/pyarrow_stubs-17.19.tar.gz", hash = "sha256:45af6b05ebb2c84352a111c54063db1be8bf54274b8af709c7f1c034e8d84527", size = 83675, upload-time = "2025-03-17T14:15:38.581Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/06/46/985eae0eb798a749ed774f4211dfa3d06fd29d297a68e247000b33f504fc/pyarrow_stubs-17.19-py3-none-any.whl", hash = "sha256:9909adfbcb1356e859d5e57967d60a13e59030ea223acbcc8693ec4683982403", size = 80170, upload-time = "2025-03-17T14:15:37.091Z" }, +] + +[[package]] +name = "pyarrow-stubs" +version = "20.0.0.20250825" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", + "python_full_version == '3.10.*'", + "python_full_version == '3.9.*'", +] +dependencies = [ + { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/03/2c/2807ba3808971a8870686304a727908f84903be8ede36a3a399a0f36a13d/pyarrow_stubs-20.0.0.20250825.tar.gz", hash = "sha256:e128e575c00a978c851d7fb2f45bf793c3e4dda5c084cfb9e20cf839829c97d9", size = 236556, upload-time = "2025-08-25T02:01:19.92Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/5f/6233b7072f3b635dd29a42cc7d1c9fee8460bf86d4089a88cbf2e1c3580f/pyarrow_stubs-20.0.0.20250825-py3-none-any.whl", hash = "sha256:f6a5242c7874f89fb5c2d8f611dca2ec1125622b53067994a42fa64193ab8d29", size = 235709, upload-time = "2025-08-25T02:01:21.17Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -1593,7 +1626,7 @@ wheels = [ [[package]] name = "wherobots-python-dbapi" -version = "0.19.0" +version = "0.20.0" source = { editable = "." 
} dependencies = [ { name = "cbor2" }, @@ -1631,6 +1664,8 @@ dev = [ { name = "pre-commit", version = "4.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "pyarrow", version = "17.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "pyarrow", version = "21.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "pyarrow-stubs", version = "17.19", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "pyarrow-stubs", version = "20.0.0.20250825", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "pytest", version = "8.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "pytest", version = "8.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "rich" }, @@ -1658,6 +1693,7 @@ dev = [ { name = "pandas" }, { name = "pre-commit" }, { name = "pyarrow", specifier = ">=14.0.2" }, + { name = "pyarrow-stubs", specifier = ">=17.0" }, { name = "pytest", specifier = ">=8.0.2" }, { name = "rich", specifier = ">=13.7.1" }, ] diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 47bbf61..9ea4612 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -24,6 +24,7 @@ ) from wherobots.db.cursor import Cursor from wherobots.db.errors import NotSupportedError, OperationalError +from wherobots.db.result_storage import Store @dataclass @@ -56,12 +57,14 @@ def __init__( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ): self.__ws = ws self.__read_timeout = read_timeout self.__results_format = results_format 
self.__data_compression = data_compression self.__geometry_representation = geometry_representation + self.__store = store self.__queries: dict[str, Query] = {} self.__thread = threading.Thread( @@ -69,10 +72,10 @@ def __init__( ) self.__thread.start() - def __enter__(self): + def __enter__(self) -> "Connection": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() def close(self) -> None: @@ -134,6 +137,9 @@ def __listen(self) -> None: # On a state_updated event telling us the query succeeded, # ask for results. if kind == EventKind.STATE_UPDATED: + logging.info( + "Query %s succeeded; full message is %s", execution_id, message + ) self.__request_results(execution_id) return @@ -169,7 +175,7 @@ def _handle_results(self, execution_id: str, results: Dict[str, Any]) -> Any: result_compression = results.get("compression") logging.info( "Received %d bytes of %s-compressed %s results from %s.", - len(result_bytes), + len(result_bytes) if result_bytes else 0, result_compression, result_format, execution_id, @@ -209,6 +215,13 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str: "statement": sql, } + if self.__store: + request["store"] = { + "format": self.__store.format.value if self.__store.format else None, + "single": str(self.__store.single), + "generate_presigned_url": str(self.__store.generate_presigned_url), + } + self.__queries[execution_id] = Query( sql=sql, execution_id=execution_id, diff --git a/wherobots/db/constants.py b/wherobots/db/constants.py index 95f2555..0e88579 100644 --- a/wherobots/db/constants.py +++ b/wherobots/db/constants.py @@ -5,6 +5,7 @@ from .region import Region from .runtime import Runtime from .session_type import SessionType +from .result_storage import StorageFormat DEFAULT_ENDPOINT: str = "api.cloud.wherobots.com" # "api.cloud.wherobots.com" @@ -14,6 +15,7 @@ DEFAULT_REGION: Region = Region.AWS_US_WEST_2 DEFAULT_VERSION: str = "latest" 
@dataclass
class Store:
    """Options for storing query results in temporary storage.

    Attributes:
        format: Storage format for the results, or ``None`` to leave it
            unspecified.
        single: Whether to generate only a single part file.
        generate_presigned_url: Whether to generate a presigned URL for the
            results. Only valid when ``single`` is also set.

    Raises:
        ValueError: If ``generate_presigned_url`` is requested without
            ``single``.
    """

    format: Union[StorageFormat, None] = None
    single: bool = False
    generate_presigned_url: bool = False

    def __post_init__(self) -> None:
        """Validate that the requested option combination is coherent."""
        # Raise explicitly rather than using `assert`: assertions are removed
        # when Python runs with -O, which would silently skip this check.
        if self.generate_presigned_url and not self.single:
            raise ValueError(
                "Presigned URL can only be generated when single part file is requested."
            )