diff --git a/pyproject.toml b/pyproject.toml index 82664c4..7dc6141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "wherobots-python-dbapi" [tool.poetry] name = "wherobots-python-dbapi" -version = "0.15.0" +version = "0.16.0" description = "Python DB-API driver for Wherobots DB" authors = ["Maxime Petazzoni "] license = "Apache 2.0" diff --git a/tests/smoke.py b/tests/smoke.py index c3fc5b0..c4f9945 100644 --- a/tests/smoke.py +++ b/tests/smoke.py @@ -11,10 +11,15 @@ from rich.table import Table from wherobots.db import connect, connect_direct -from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE +from wherobots.db.constants import ( + DEFAULT_ENDPOINT, + DEFAULT_SESSION_TYPE, + DEFAULT_STORAGE_FORMAT, +) from wherobots.db.connection import Connection from wherobots.db.region import Region from wherobots.db.session_type import SessionType +from wherobots.db.result_storage import StorageFormat, Store if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -47,7 +52,35 @@ parser.add_argument( "--wide", help="Enable wide output", action="store_const", const=80, default=30 ) + parser.add_argument( + "-s", + "--store", + help="Store results in temporary storage", + action="store_true", + ) parser.add_argument("sql", nargs="+", help="SQL query to execute") + + args, unknown = parser.parse_known_args() + if args.store: + parser.add_argument( + "-sf", + "--storage-format", + help="Storage format for the results", + default=DEFAULT_STORAGE_FORMAT, + choices=[sf.value for sf in StorageFormat], + ) + parser.add_argument( + "--single", + help="Generate only a single part file", + action="store_true", + ) + parser.add_argument( + "-p", + "--presigned-url", + help="Generate a presigned URL for the results (only when --single is set)", + action="store_true", + ) + args = parser.parse_args() logging.basicConfig( @@ -72,6 +105,16 @@ token = f.read().strip() headers = {"Authorization": f"Bearer {token}"} + store = None + if args.store: + store = Store( + format=StorageFormat(args.storage_format) + if args.storage_format + else DEFAULT_STORAGE_FORMAT, + single=args.single, + generate_presigned_url=args.presigned_url, + ) + if args.ws_url: conn_func = functools.partial(connect_direct, uri=args.ws_url, headers=headers) else: @@ -84,6 +127,7 @@ wait_timeout=900, region=Region(args.region) if args.region else Region.AWS_US_WEST_2, session_type=SessionType(args.session_type), + store=store, ) def render(results: pandas.DataFrame) -> None: diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 47bbf61..7b0a586 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -24,6 +24,7 @@ ) from wherobots.db.cursor import Cursor from wherobots.db.errors import NotSupportedError, OperationalError +from wherobots.db.result_storage import Store @dataclass @@ -56,12 +57,14 @@ def __init__( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ): self.__ws = ws self.__read_timeout = read_timeout self.__results_format = results_format self.__data_compression = data_compression self.__geometry_representation = geometry_representation + self.__store = store self.__queries: dict[str, Query] = {} self.__thread = threading.Thread( @@ -134,6 +137,9 @@ def __listen(self) -> None: # On a state_updated event telling us the query succeeded, # ask for results. if kind == EventKind.STATE_UPDATED: + logging.info( + "Query %s succeeded; full message is %s", execution_id, message + ) self.__request_results(execution_id) return @@ -209,6 +215,17 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str: "statement": sql, } + if self.__store: + request["store"] = {} + if self.__store.format: + request["store"]["format"] = self.__store.format.value + if self.__store.single: + request["store"]["single"] = str(self.__store.single) + if self.__store.generate_presigned_url: + request["store"]["generate_presigned_url"] = str( + self.__store.generate_presigned_url + ) + self.__queries[execution_id] = Query( sql=sql, execution_id=execution_id, diff --git a/wherobots/db/constants.py b/wherobots/db/constants.py index 53c5b4b..7eb1e10 100644 --- a/wherobots/db/constants.py +++ b/wherobots/db/constants.py @@ -5,6 +5,7 @@ from .region import Region from .runtime import Runtime from .session_type import SessionType +from .result_storage import StorageFormat DEFAULT_ENDPOINT: str = "api.cloud.wherobots.com" # "api.cloud.wherobots.com" @@ -13,6 +14,7 @@ DEFAULT_RUNTIME: Runtime = Runtime.TINY DEFAULT_REGION: Region = Region.AWS_US_WEST_2 DEFAULT_SESSION_TYPE: SessionType = SessionType.SINGLE +DEFAULT_STORAGE_FORMAT: StorageFormat = StorageFormat.PARQUET DEFAULT_READ_TIMEOUT_SECONDS: float = 0.25 DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS: float = 900 diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index 81573b1..e3e0c87 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -38,6 +38,7 @@ ) from .region import Region from .runtime import Runtime +from .result_storage import Store apilevel = "2.0" threadsafety = 1 @@ -69,6 +70,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -151,6 +153,7 @@ def get_session_uri() -> str: results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + store=store, ) @@ -171,6 +174,7 @@ def connect_direct( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ) -> Connection: uri_with_protocol = f"{uri}/{protocol}" @@ -193,4 +197,5 @@ def connect_direct( results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + store=store, ) diff --git a/wherobots/db/result_storage.py b/wherobots/db/result_storage.py new file mode 100644 index 0000000..69362ef --- /dev/null +++ b/wherobots/db/result_storage.py @@ -0,0 +1,31 @@ +from enum import auto +from strenum import LowercaseStrEnum +from typing import Union + + +class StorageFormat(LowercaseStrEnum): + PARQUET = auto() + CSV = auto() + GEOJSON = auto() + GEOPARQUET = auto() + + +class Store: + def __init__( + self, + format: Union[StorageFormat, None] = None, + single: bool = False, + generate_presigned_url: bool = False, + ): + self.format = format + self.single = single + self.generate_presigned_url = generate_presigned_url + assert ( + single or not generate_presigned_url + ), "Presigned URL can only be generated when single part file is requested." + + def __repr__(self): + return f"Store(format={self.format}, single={self.single}, generate_presigned_url={self.generate_presigned_url})" + + def __str__(self): + return f"Store(format={self.format}, single={self.single}, generate_presigned_url={self.generate_presigned_url})"