From d6d9effe8440b6257c3c673ff71c25ee11341ff0 Mon Sep 17 00:00:00 2001
From: Cody Fincher
Date: Thu, 24 Apr 2025 03:09:15 +0000
Subject: [PATCH 1/2] feat(spanner): Google Spanner Driver

Adds a synchronous Google Cloud Spanner adapter: SpannerConfig manages the
client, database handle, and session pool, while SpannerDriver executes
statements within a Spanner Snapshot or Transaction context.

Example usage (with placeholder project/instance/database IDs):

    config = SpannerConfig(project="my-project", instance_id="my-instance", database_id="my-db")
    with config.provide_session() as driver:
        rows = driver.select("SELECT 1 AS one")

---
 sqlspec/adapters/spanner/__init__.py | 4 +
 sqlspec/adapters/spanner/config/__init__.py | 3 +
 sqlspec/adapters/spanner/config/_sync.py | 274 ++++
 sqlspec/adapters/spanner/driver.py | 657 ++++++++++++++++++
 .../test_adapters/test_spanner/__init__.py | 0
 .../test_spanner/test_connection.py | 163 +++++
 .../test_adapters/test_spanner/test_driver.py | 177 +++++
 7 files changed, 1278 insertions(+)
 create mode 100644 sqlspec/adapters/spanner/__init__.py
 create mode 100644 sqlspec/adapters/spanner/config/__init__.py
 create mode 100644 sqlspec/adapters/spanner/config/_sync.py
 create mode 100644 sqlspec/adapters/spanner/driver.py
 create mode 100644 tests/integration/test_adapters/test_spanner/__init__.py
 create mode 100644 tests/integration/test_adapters/test_spanner/test_connection.py
 create mode 100644 tests/integration/test_adapters/test_spanner/test_driver.py

diff --git a/sqlspec/adapters/spanner/__init__.py b/sqlspec/adapters/spanner/__init__.py
new file mode 100644
index 0000000..a8c1db6
--- /dev/null
+++ b/sqlspec/adapters/spanner/__init__.py
@@ -0,0 +1,4 @@
+from .config import SpannerConfig, SpannerPoolConfig
+from .driver import SpannerConnection, SpannerDriver
+
+__all__ = ("SpannerConfig", "SpannerPoolConfig", "SpannerConnection", "SpannerDriver")
diff --git a/sqlspec/adapters/spanner/config/__init__.py b/sqlspec/adapters/spanner/config/__init__.py
new file mode 100644
index 0000000..cbab9a8
--- /dev/null
+++ b/sqlspec/adapters/spanner/config/__init__.py
@@ -0,0 +1,3 @@
+from ._sync import SpannerConfig, SpannerPoolConfig
+
+__all__ = ("SpannerConfig", "SpannerPoolConfig")
diff --git a/sqlspec/adapters/spanner/config/_sync.py b/sqlspec/adapters/spanner/config/_sync.py
new file mode 100644
index 0000000..f7b4a4a
--- /dev/null
+++ b/sqlspec/adapters/spanner/config/_sync.py
@@ -0,0 +1,274 @@
+import logging
+import threading
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from google.cloud.spanner_v1 import Client
+from google.cloud.spanner_v1.database import Database
+from google.cloud.spanner_v1.pool import AbstractSessionPool, FixedSizePool, PingingPool, TransactionPingingPool
+from google.cloud.spanner_v1.snapshot import Snapshot
+from google.cloud.spanner_v1.transaction import Transaction
+
+from sqlspec.adapters.spanner.driver import SpannerDriver
+from sqlspec.base import SyncDatabaseConfig
+from sqlspec.exceptions import ImproperConfigurationError
+from sqlspec.typing import dataclass_to_dict
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
+    from google.auth.credentials import Credentials
+
+# Define the Connection Type alias
+SpannerSyncConnection = Union[Snapshot, Transaction]
+
+# Get logger instance
+logger = logging.getLogger("sqlspec")
+
+__all__ = ("SpannerConfig", "SpannerPoolConfig")
+
+
+@dataclass
+class SpannerPoolConfig:
+    """Configuration for the Spanner session pool.
+
+    Ref: https://cloud.google.com/python/docs/reference/spanner/latest/advanced-session-pool-topics
+    """
+
+    pool_type: type[AbstractSessionPool] = FixedSizePool
+    """The type of session pool to use. Defaults to FixedSizePool."""
+    min_sessions: int = 1
+    """The minimum number of sessions to keep in the pool."""
+    max_sessions: int = 10
+    """The maximum number of sessions allowed in the pool."""
+    labels: Optional[dict[str, str]] = None
+    """Labels to apply to sessions created by the pool."""
+    ping_interval: int = 300  # Default 5 minutes
+    """Interval (in seconds) for pinging sessions in PingingPool/TransactionPingingPool."""
+    # Additional pool-specific options can be added here as needed.
+
+
+@dataclass
+class SpannerConfig(
+    SyncDatabaseConfig[SpannerSyncConnection, AbstractSessionPool, SpannerDriver]
+):
+    """Synchronous Google Cloud Spanner database Configuration.
+
+    This class provides the configuration for Spanner database connections.
+    """
+
+    project: Optional[str] = None
+    """Google Cloud project ID."""
+    instance_id: Optional[str] = None
+    """Spanner instance ID."""
+    database_id: Optional[str] = None
+    """Spanner database ID."""
+    credentials: Optional["Credentials"] = None
+    """Optional Google Cloud credentials. If None, uses Application Default Credentials."""
+    client_options: Optional[dict[str, Any]] = None
+    """Optional dictionary of client options for the Spanner client."""
+    pool_config: Optional[SpannerPoolConfig] = field(default_factory=SpannerPoolConfig)
+    """Spanner session pool configuration."""
+    pool_instance: Optional[AbstractSessionPool] = None
+    """Optional pre-configured pool instance to use."""
+
+    # Define actual types
+    connection_type: "type[SpannerSyncConnection]" = field(init=False, default=Union[Snapshot, Transaction])  # type: ignore
+    driver_type: "type[SpannerDriver]" = field(init=False, default=SpannerDriver)
+
+    _client: Optional[Client] = field(init=False, default=None, repr=False, hash=False)
+    _database: Optional[Database] = field(init=False, default=None, repr=False, hash=False)
+    _ping_thread: "Optional[threading.Thread]" = field(init=False, default=None, repr=False, hash=False)
+
+    def __post_init__(self) -> None:
+        # Basic check, more robust checks might be needed later
+        if self.pool_instance and not self.pool_config:
+            # If a pool instance is provided, we might not need pool_config
+            pass
+        elif not self.pool_config:
+            # Create default if not provided and pool_instance is also None
+            self.pool_config = SpannerPoolConfig()
+
+    @property
+    def client(self) -> Client:
+        """Provides the Spanner Client, creating it if necessary."""
+        if self._client is None:
+            self._client = Client(
+                project=self.project,
+                credentials=self.credentials,
+                client_options=self.client_options,
+            )
+        return self._client
+
+    @property
+    def database(self) -> Database:
+        """Provides the Spanner Database instance, creating client, pool, and database if necessary.
+
+        The client, session pool, and Database handle are created lazily on
+        first access and cached for subsequent calls.
+
+        Note:
+            The session pool is created via provide_pool() if one does not
+            already exist.
+
+        Raises:
+            ImproperConfigurationError: If project, instance, and database IDs are not configured.
+
+        Returns:
+            The configured database instance.
+        """
+        if self._database is None:
+            if not self.project or not self.instance_id or not self.database_id:
+                msg = "Project, instance, and database IDs must be configured."
+ raise ImproperConfigurationError(msg) + + # Ensure client exists + spanner_client = self.client + # Ensure pool exists (this will create it if needed) + pool = self.provide_pool() + + # Get instance object + instance = spanner_client.instance(self.instance_id) # type: ignore[no-untyped-call] + + # Create the final Database object using the created pool + self._database = instance.database(database_id=self.database_id, pool=pool) + return self._database + + def provide_pool(self, *args: Any, **kwargs: Any) -> AbstractSessionPool: + """Provides the configured session pool, creating it if necessary . + + This method ensures that the session pool is created and configured correctly. + It also handles any additional configuration options that may be needed for the pool. + + Args: + *args: Additional positional arguments to pass to the pool constructor. + **kwargs: Additional keyword arguments to pass to the pool constructor. + + Raises: + ImproperConfigurationError: If pool_config is not set or project, instance, and database IDs are not configured. + + Returns: + The configured session pool. + """ + if self.pool_instance: + return self.pool_instance + + if not self.pool_config: + # This should be handled by __post_init__, but double-check + msg = "pool_config must be set if pool_instance is not provided." + raise ImproperConfigurationError(msg) + + if not self.project or not self.instance_id or not self.database_id: + msg = "Project, instance, and database IDs must be configured to create pool." + raise ImproperConfigurationError(msg) + + instance = self.client.instance(self.instance_id) + + pool_kwargs = dataclass_to_dict(self.pool_config, exclude_empty=True, exclude={"pool_type"}) + + # Only include ping_interval if using a relevant pool type + if not issubclass(self.pool_config.pool_type, (PingingPool, TransactionPingingPool)): + pool_kwargs.pop("ping_interval", None) + + self.pool_instance = self.pool_config.pool_type( + database=Database(database_id=self.database_id, instance=instance), # pyright: ignore + **pool_kwargs, + ) + + # Start pinging thread if applicable and not already running + if isinstance(self.pool_instance, (PingingPool, TransactionPingingPool)) and self._ping_thread is None: + self._ping_thread = threading.Thread( + target=self.pool_instance.ping, + daemon=True, # Ensure thread exits with application + name=f"spanner-ping-{self.project}-{self.instance_id}-{self.database_id}", + ) + self._ping_thread.start() + logger.debug("Started Spanner background ping thread for %s", self.pool_instance) + + return self.pool_instance + + @contextmanager + def provide_connection(self, *args: Any, **kwargs: Any) -> "Generator[SpannerSyncConnection, None, None]": + """Provides a Spanner snapshot context (suitable for reads). + + This method ensures that the connection is created and configured correctly. + It also handles any additional configuration options that may be needed for the connection. + + Args: + *args: Additional positional arguments to pass to the connection constructor. + **kwargs: Additional keyword arguments to pass to the connection constructor. + + Yields: + The configured connection. + """ + db = self.database # Ensure database and pool are initialized + with db.snapshot() as snapshot: + yield snapshot # Replace with actual connection object later + + @contextmanager + def provide_session(self, *args: Any, **kwargs: Any) -> "Generator[SpannerDriver, None, None]": + """Provides a driver instance initialized with a connection context (Snapshot). 
+ + This method ensures that the driver is created and configured correctly. + It also handles any additional configuration options that may be needed for the driver. + + Args: + *args: Additional positional arguments to pass to the driver constructor. + **kwargs: Additional keyword arguments to pass to the driver constructor. + + Yields: + The configured driver. + """ + with self.provide_connection(*args, **kwargs) as connection: + yield self.driver_type(connection) # pyright: ignore + + def close_pool(self) -> None: + """Clears internal references to the pool, database, and client.""" + # Spanner pool doesn't require explicit closing usually. + self.pool_instance = None + self._database = None + self._client = None + # Clear thread reference, but don't need to join (it's daemon) + self._ping_thread = None + + @property + def connection_config_dict(self) -> "dict[str, Any]": + """Returns connection-related parameters.""" + config = { + "project": self.project, + "instance_id": self.instance_id, + "database_id": self.database_id, + "credentials": self.credentials, + "client_options": self.client_options, + } + return {k: v for k, v in config.items() if v is not None} + + @property + def pool_config_dict(self) -> "dict[str, Any]": + """Returns pool configuration parameters. + + This method ensures that the pool configuration is returned correctly. + It also handles any additional configuration options that may be needed for the pool. + + Args: + *args: Additional positional arguments to pass to the pool constructor. + **kwargs: Additional keyword arguments to pass to the pool constructor. + + Raises: + ImproperConfigurationError: If pool_config is not set or project, instance, and database IDs are not configured. + + Returns: + The pool configuration parameters. + """ + if self.pool_config: + return dataclass_to_dict(self.pool_config, exclude_empty=True) + # If pool_config was not initially provided but pool_instance was, + # this method might be called unexpectedly. Add check. + if self.pool_instance: + # We can't reconstruct the config dict from the instance easily. + msg = "Cannot retrieve pool_config_dict when initialized with pool_instance." + raise ImproperConfigurationError(msg) + # Should not be reachable if __post_init__ runs correctly + msg = "pool_config is not set." 
+        raise ImproperConfigurationError(msg)
diff --git a/sqlspec/adapters/spanner/driver.py b/sqlspec/adapters/spanner/driver.py
new file mode 100644
index 0000000..c9a6ea2
--- /dev/null
+++ b/sqlspec/adapters/spanner/driver.py
@@ -0,0 +1,657 @@
+import logging
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Optional,
+    Union,
+    cast,
+    overload,
+)
+
+# Spanner imports
+# Use specific imports for clarity and potential type stub resolution
+from google.cloud.spanner_v1 import Transaction, exceptions, param_types  # pyright: ignore
+
+# sqlspec imports
+from sqlspec.base import (
+    SyncDriverAdapterProtocol,
+)
+from sqlspec.exceptions import NotFoundError, SQLConversionError, SQLParsingError
+from sqlspec.mixins import SQLTranslatorMixin
+from sqlspec.statement import PARAM_REGEX, SQLStatement
+from sqlspec.typing import ModelDTOT, StatementParameterType, T
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from google.cloud.spanner_v1.streamed import StreamedResultSet
+
+# Define Connection types matching base protocol
+SpannerConnection = Transaction
+
+
+logger = logging.getLogger("sqlspec")
+
+__all__ = ("SpannerConnection", "SpannerDriver")
+
+
+# --- Helper Functions ---
+
+
+def _spanner_row_to_dict(row: "Sequence[Any]", fields: "list[str]") -> "dict[str, Any]":
+    """Converts a Spanner result row (sequence) to a dictionary."""
+    return dict(zip(fields, row))
+
+
+# --- Base Parameter Processing (Shared Logic) ---
+
+
+def _base_process_sql_params(
+    sql: str, parameters: "Optional[StatementParameterType]", dialect: str, kwargs: "Optional[dict[str, Any]]"
+) -> "tuple[str, Optional[dict[str, Any]]]":
+    """Process SQL and parameters for Spanner, converting :param -> @param.
+
+    Returns the processed SQL and the parameter dictionary.
+    """
+    stmt = SQLStatement(sql=sql, parameters=parameters, dialect=dialect, kwargs=kwargs or None)
+    processed_sql, processed_params = stmt.process()
+
+    param_dict: Optional[dict[str, Any]] = None
+
+    if isinstance(processed_params, (list, tuple)):
+        msg = "Spanner requires named parameters (dict), not positional parameters."
+        raise SQLParsingError(msg)
+    if isinstance(processed_params, dict):
+        param_dict = processed_params
+        # Convert :param style to @param style for Spanner
+        processed_sql_parts: list[str] = []
+        last_end = 0
+        found_params_regex: list[str] = []
+
+        # Use PARAM_REGEX from statement module
+        for match in PARAM_REGEX.finditer(processed_sql):
+            # Skip matches inside quotes or comments if PARAM_REGEX handles them
+            # (Assuming PARAM_REGEX correctly ignores quoted/commented sections)
+            # if match.group("dquote") or match.group("squote") or match.group("comment"):
+            #     continue
+
+            var_match = match.group("var_name_colon")
+            perc_match = match.group("var_name_perc")  # Check for %(param)s style too
+
+            if var_match:
+                var_name = var_match
+                start_char = ":"
+                start_idx = match.start("var_name_colon") - 1  # Position of ':'
+                end_idx = match.end("var_name_colon")
+            elif perc_match:
+                # Need to adjust indices for %(...)s structure
+                var_name = perc_match
+                start_char = "%("  # This won't be used directly below, just for error msg
+                start_idx = match.start("var_name_perc") - 2  # Position of '%'
+                end_idx = match.end("var_name_perc") + 3  # Position after ')s'
+            else:
+                continue  # Skip non-parameter matches
+
+            found_params_regex.append(var_name)
+
+            if var_name not in param_dict:
+                msg = (
+                    f"Named parameter '{start_char}{var_name}' found in SQL but missing from parameters. "
+                    f"SQL: {processed_sql}, Params: {param_dict.keys()}"
+                )
+                raise SQLParsingError(msg)
+
+            processed_sql_parts.extend((processed_sql[last_end:start_idx], f"@{var_name}"))
+            last_end = end_idx
+
+        processed_sql_parts.append(processed_sql[last_end:])
+        final_sql = "".join(processed_sql_parts)
+
+        # If no :param or %(param)s found, but we have a dict, assume user wrote @param directly
+        if not found_params_regex and param_dict:
+            logger.debug(
+                "Dict params provided (%s), but no standard ':%s' or '%%(%s)s' placeholders found. "
+                "Assuming SQL uses @param directly. SQL: %s",
+                list(param_dict.keys()),
+                "param",
+                "param",
+                processed_sql,
+            )
+            return processed_sql, param_dict  # Return original SQL
+
+        return final_sql, param_dict
+
+    # If parameters is None or not a dict/list/tuple after processing
+    return processed_sql, None
+
+
+def _get_spanner_param_types(params: "Optional[dict[str, Any]]") -> "dict[str, Any]":
+    """Generate basic Spanner parameter types (defaults to STRING).
+
+    Placeholder: A more robust implementation would inspect param values.
+    """
+    # TODO: Enhance with actual type inference or allow user override via `param_types`
+    return dict.fromkeys(params, param_types.STRING) if params else {}
+
+
+# --- Synchronous Driver ---
+
+
+class SpannerDriver(
+    SyncDriverAdapterProtocol["SpannerConnection"],
+    SQLTranslatorMixin["SpannerConnection"],
+):
+    """Spanner Sync Driver Adapter.
+
+    Operates within a specific Spanner Snapshot or Transaction context.
+    """
+
+    dialect: str = "spanner"
+
+    def __init__(self, connection: "SpannerConnection", **kwargs: Any) -> None:
+        """Initialize with a Spanner Snapshot or Transaction."""
+        self.connection = connection
+        # kwargs are ignored for now, consistent with protocol
+
+    def _process_sql_params(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        **kwargs: Any,
+    ) -> "tuple[str, Optional[dict[str, Any]]]":
+        """Process SQL and parameters for Spanner sync driver."""
+        return _base_process_sql_params(sql, parameters, self.dialect, kwargs)
+
+    def _execute_sql(
+        self, sql: str, params: "Optional[dict[str, Any]]", context: "SpannerConnection"
+    ) -> "StreamedResultSet":
+        """Executes SQL using the provided Snapshot or Transaction."""
+        types = _get_spanner_param_types(params)
+        try:
+            return context.execute_sql(sql, params or {}, types)
+        except exceptions.NotFound as e:
+            # Intercept NotFound early if possible, though typically raised on iteration
+            msg = f"Spanner query execution failed: {e}"
+            raise NotFoundError(msg) from e
+        except exceptions.InvalidArgument as e:
+            msg = f"Invalid argument during Spanner query execution: {e}. SQL: {sql}, Params: {params}"
+            raise SQLParsingError(
+                msg
+            ) from e
+        except Exception as e:
+            # Catch other potential Spanner or network errors
+            msg = f"Spanner query execution error: {e}"
+            raise SQLConversionError(msg) from e
+
+    def _execute_update(self, sql: str, params: "Optional[dict[str, Any]]", transaction: "Transaction") -> int:
+        """Executes DML using the provided Transaction.
+
+        Returns:
+            The number of rows affected by the DML statement.
+        """
+        types = _get_spanner_param_types(params)
+        try:
+            # execute_update returns the count of rows modified by the DML statement.
+            row_count: int = transaction.execute_update(sql, params or {}, types)
+            # Surface the affected row count directly to the caller.
+            return row_count
+        except exceptions.NotFound as e:
+            msg = f"Spanner update execution failed: {e}"
+            raise NotFoundError(msg) from e
+        except exceptions.InvalidArgument as e:
+            msg = f"Invalid argument during Spanner update execution: {e}. SQL: {sql}, Params: {params}"
+            raise SQLParsingError(
+                msg
+            ) from e
+        except Exception as e:
+            msg = f"Spanner update execution error: {e}"
+            raise SQLConversionError(msg) from e
+
+    @overload
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Sequence[dict[str, Any]]": ...
+
+    @overload
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "Sequence[ModelDTOT]": ...
+
+    def select(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Sequence[Union[ModelDTOT, dict[str, Any]]]":
+        """Execute a SELECT query and return all results.
+
+        Args:
+            sql: The SQL query to execute.
+            parameters: Optional parameters for the query.
+            connection: Optional connection to use instead of the default.
+            schema_type: Optional schema type to convert results to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            A sequence of results, either as dictionaries or instances of schema_type.
+        """
+        context = connection or self.connection
+        processed_sql, params = self._process_sql_params(sql, parameters, **kwargs)
+        result_set = self._execute_sql(processed_sql, params, context)
+
+        # Convert rows to dictionaries
+        results = [
+            _spanner_row_to_dict(row, [field.name for field in result_set.metadata.row_type.fields])  # pyright: ignore
+            for row in result_set
+        ]
+
+        # Convert to schema type if specified
+        if schema_type:
+            return [schema_type(**result) for result in results]
+
+        return results
+
+    @overload
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+
+    @overload
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+
+    def select_one(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Union[ModelDTOT, dict[str, Any]]":
+        """Execute a SELECT query and return the first result.
+
+        Args:
+            sql: The SQL query to execute.
+            parameters: Optional parameters for the query.
+            connection: Optional connection to use instead of the default.
+            schema_type: Optional schema type to convert results to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The first result, either as a dictionary or an instance of schema_type.
+
+        Raises:
+            NotFoundError: If no results are found.
+ """ + context = connection or self.connection + processed_sql, params = self._process_sql_params(sql, parameters, **kwargs) + result_set = self._execute_sql(processed_sql, params, context) + + try: + # Get first row + row = next(result_set) # pyright: ignore + except StopIteration: + msg = "No results found for query" + raise NotFoundError(msg) + + # Convert row to dictionary + result = _spanner_row_to_dict(row, [field.name for field in result_set.metadata.row_type.fields]) # pyright: ignore + + # Convert to schema type if specified + if schema_type: + return schema_type(**result) + + return result + + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> "Optional[dict[str, Any]]": ... + + @overload + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: "type[ModelDTOT]", + **kwargs: Any, + ) -> "Optional[ModelDTOT]": ... + + def select_one_or_none( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: "Optional[type[ModelDTOT]]" = None, + **kwargs: Any, + ) -> "Optional[Union[ModelDTOT, dict[str, Any]]]": + """Execute a SELECT query and return the first result or None. + + Args: + sql: The SQL query to execute. + parameters: Optional parameters for the query. + connection: Optional connection to use instead of the default. + schema_type: Optional schema type to convert results to. + **kwargs: Additional keyword arguments. + + Returns: + The first result, either as a dictionary or an instance of schema_type, + or None if no results are found. + """ + context = connection or self.connection + processed_sql, params = self._process_sql_params(sql, parameters, **kwargs) + result_set = self._execute_sql(processed_sql, params, context) + + try: + # Get first row + row = next(result_set) # pyright: ignore + except StopIteration: + return None + + # Convert row to dictionary + result = _spanner_row_to_dict(row, [field.name for field in result_set.metadata.row_type.fields]) # pyright: ignore + + # Convert to schema type if specified + if schema_type: + return schema_type(**result) + + return result + + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: None = None, + **kwargs: Any, + ) -> Any: ... + + @overload + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: "type[T]", + **kwargs: Any, + ) -> T: ... + + def select_value( + self, + sql: str, + parameters: "Optional[StatementParameterType]" = None, + /, + *, + connection: "Optional[SpannerConnection]" = None, + schema_type: "Optional[type[T]]" = None, + **kwargs: Any, + ) -> "Union[T, Any]": + """Execute a SELECT query and return the first value of the first result. + + Args: + sql: The SQL query to execute. + parameters: Optional parameters for the query. + connection: Optional connection to use instead of the default. + schema_type: Optional schema type to convert the value to. + **kwargs: Additional keyword arguments. 
+
+        Returns:
+            The first value of the first result, optionally converted to schema_type.
+
+        Raises:
+            NotFoundError: If no results are found.
+        """
+        context = connection or self.connection
+        processed_sql, params = self._process_sql_params(sql, parameters, **kwargs)
+
+        try:
+            # Get first row (StreamedResultSet is iterable, not an iterator)
+            row = next(iter(self._execute_sql(processed_sql, params, context)))  # pyright: ignore
+        except StopIteration:
+            msg = "No results found for query"
+            raise NotFoundError(msg) from None
+
+        if schema_type:
+            return cast("T", schema_type(row[0]))  # pyright: ignore
+
+        return row[0]
+
+    @overload
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "Optional[Any]": ...
+
+    @overload
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "type[T]",
+        **kwargs: Any,
+    ) -> "Optional[T]": ...
+
+    def select_value_or_none(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "Optional[type[T]]" = None,
+        **kwargs: Any,
+    ) -> "Optional[Union[T, Any]]":
+        """Execute a SELECT query and return the first value of the first result or None.
+
+        Args:
+            sql: The SQL query to execute.
+            parameters: Optional parameters for the query.
+            connection: Optional connection to use instead of the default.
+            schema_type: Optional schema type to convert the value to.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The first value of the first result, optionally converted to schema_type,
+            or None if no results are found.
+        """
+        context = connection or self.connection
+        processed_sql, params = self._process_sql_params(sql, parameters, **kwargs)
+
+        try:
+            # Get first row (StreamedResultSet is iterable, not an iterator)
+            row = next(iter(self._execute_sql(processed_sql, params, context)))  # pyright: ignore
+        except StopIteration:
+            return None
+
+        # Convert to schema type if specified
+        if schema_type:
+            return cast("T", schema_type(row[0]))  # pyright: ignore
+
+        return row[0]
+
+    def insert_update_delete(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        **kwargs: Any,
+    ) -> int:
+        """Execute an INSERT, UPDATE, or DELETE statement.
+
+        Args:
+            sql: The SQL statement to execute.
+            parameters: Optional parameters for the statement.
+            connection: Optional connection to use instead of the default.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The number of rows affected by the statement.
+
+        Raises:
+            SQLConversionError: If the statement execution fails.
+        """
+        context = connection or self.connection
+        if not isinstance(context, Transaction):  # pyright: ignore
+            msg = "INSERT/UPDATE/DELETE operations require a Transaction"
+            raise SQLConversionError(msg)
+
+        processed_sql, params = self._process_sql_params(sql, parameters, **kwargs)
+        return self._execute_update(processed_sql, params, context)
+
+    @overload
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: None = None,
+        **kwargs: Any,
+    ) -> "dict[str, Any]": ...
+
+    @overload
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "type[ModelDTOT]",
+        **kwargs: Any,
+    ) -> "ModelDTOT": ...
+
+    def insert_update_delete_returning(
+        self,
+        sql: str,
+        parameters: "Optional[StatementParameterType]" = None,
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        schema_type: "Optional[type[ModelDTOT]]" = None,
+        **kwargs: Any,
+    ) -> "Union[ModelDTOT, dict[str, Any]]":
+        """Execute an INSERT, UPDATE, or DELETE statement with RETURNING clause.
+
+        Note: This adapter does not implement row-returning DML (Spanner's GoogleSQL uses THEN RETURN rather than RETURNING), so this method always raises.
+
+        Args:
+            sql: The SQL statement to execute.
+            parameters: Optional parameters for the statement.
+            connection: Optional connection to use instead of the default.
+            schema_type: Optional schema type to convert results to.
+            **kwargs: Additional keyword arguments.
+
+        Raises:
+            SQLConversionError: Always raised, as this adapter does not support RETURNING DML.
+        """
+        msg = "Spanner doesn't support RETURNING DML"
+        raise SQLConversionError(msg)
+
+    def execute_script(
+        self,
+        sql: str,  # Should contain multiple statements typically
+        parameters: "Optional[StatementParameterType]" = None,  # Params might not be applicable to scripts
+        /,
+        *,
+        connection: "Optional[SpannerConnection]" = None,
+        **kwargs: Any,
+    ) -> str:  # Protocol expects string status
+        """Execute a SQL script containing multiple statements.
+
+        Args:
+            sql: The SQL script to execute.
+            parameters: Optional parameters for the script.
+            connection: Optional connection to use instead of the default.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            A status string indicating success.
+
+        Raises:
+            SQLConversionError: If the script execution fails.
+        """
+        context = connection or self.connection
+        if not isinstance(context, Transaction):  # pyright: ignore
+            msg = "Script execution requires a Transaction"
+            raise SQLConversionError(msg)
+
+        sql, parameters = self._process_sql_params(sql, parameters, **kwargs)
+        try:
+            # Execute each statement in the script
+            for statement in sql.split(";"):
+                statement = statement.strip()
+                if statement:
+                    if statement.upper().startswith(("SELECT", "WITH")):
+                        self._execute_sql(statement, parameters, context)
+                    else:
+                        self._execute_update(statement, parameters, context)
+            return "Script executed successfully"
+        except Exception as e:
+            msg = f"Script execution failed: {e}"
+            raise SQLConversionError(msg) from e
diff --git a/tests/integration/test_adapters/test_spanner/__init__.py b/tests/integration/test_adapters/test_spanner/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/test_adapters/test_spanner/test_connection.py b/tests/integration/test_adapters/test_spanner/test_connection.py
new file mode 100644
index 0000000..8bde452
--- /dev/null
+++ b/tests/integration/test_adapters/test_spanner/test_connection.py
@@ -0,0 +1,163 @@
+"""Spanner Configuration and Connection Integration Tests using pytest-databases."""
+
+from typing import Any
+
+import pytest
+from google.cloud.spanner_v1.pool import AbstractSessionPool
+from pytest_databases.docker.spanner import SpannerService
+
+# Import sqlspec types
+from sqlspec.adapters.spanner import SpannerConfig, SpannerPoolConfig
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_project(spanner_service: SpannerService) -> str:
+    return spanner_service.project
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_instance(spanner_service: SpannerService) -> str:
+    return spanner_service.instance  # type: ignore[attr-defined]
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_database(spanner_service: SpannerService) -> str:
+    return spanner_service.database  # type: ignore[attr-defined]
+
+
+@pytest.fixture(scope="module")  # Use module scope for config fixtures
+def sync_config(
+    spanner_emulator_project: str,
+    spanner_emulator_instance: str,
+    spanner_emulator_database: str,
+) -> Any:  # -> SpannerConfig:
+    """Provides a SpannerConfig configured for the emulator."""
+    config = SpannerConfig(
+        project=spanner_emulator_project,
+        instance_id=spanner_emulator_instance,
+        database_id=spanner_emulator_database,
+        # Project/instance/database IDs belong on SpannerConfig itself.
+        pool_config=SpannerPoolConfig(),
+    )
+    yield config
+    # Cleanup pool resources after tests in the module are done
+    config.close_pool()
+
+
+def test_sync_config_properties(sync_config: Any) -> None:
+    assert sync_config.is_async is False
+    assert sync_config.support_connection_pooling is True  # Spanner uses pools
+    assert issubclass(sync_config.driver_type, SpannerSyncDriver)
+    # Check connection_type can be resolved (might need adjustment based on actual Union)
+    assert sync_config.connection_type is not None
+
+
+def test_sync_provide_pool(sync_config: Any) -> None:
+    pool = sync_config.provide_pool()
+    assert pool is not None
+    assert isinstance(pool, AbstractSessionPool)  # Check type
+    assert pool is sync_config.pool_instance
+    pool2 = sync_config.provide_pool()
+    assert pool is pool2  # Should return the same instance
+
+
+def test_sync_provide_connection(sync_config: Any) -> None:
+    # provide_connection for Spanner yields a read-only Snapshot context
+    from google.cloud.spanner_v1.snapshot import Snapshot  # Import here for isinstance check
+
+    with sync_config.provide_connection() as connection:
+        assert connection is not None
+        # Check if connection is of the expected Spanner sync type (Snapshot)
+        assert isinstance(connection, Snapshot)
+        # Check if context manager cleaned up properly (specific checks depend on impl)
+
+
+def test_sync_provide_session(sync_config: Any) -> None:
+    from google.cloud.spanner_v1.snapshot import Snapshot  # Import here for isinstance check
+
+    with sync_config.provide_session() as driver:
+        assert isinstance(driver, sync_config.driver_type)
+        assert driver.connection is not None
+        assert isinstance(driver.connection, Snapshot)
+
+
+def test_sync_close_pool(sync_config: Any) -> None:
+    # Need a fresh config instance for this test to avoid state pollution
+    # Re-create config based on emulator details
+    config = SpannerConfig(
+        project=sync_config.project,
+        instance_id=sync_config.instance_id,
+        database_id=sync_config.database_id,
+    )
+    _pool = config.provide_pool()  # Ensure pool exists
+    assert config.pool_instance is not None
+    # The default FixedSizePool does not start a background ping thread
+    assert hasattr(config, "_ping_thread")
+    assert config._ping_thread is None
+    config.close_pool()
+    assert config.pool_instance is None
+    assert config._database is None  # Check internal cleanup
+    assert config._client is None
+    assert hasattr(config, "_ping_thread")
+    assert config._ping_thread is None
+
+
+# --- Async Tests ---
+
+
+@pytest.mark.asyncio
+async def test_async_config_properties(async_config: Any) -> None:
+    assert async_config.is_async is True
+    assert async_config.support_connection_pooling is True
+    assert issubclass(async_config.driver_type, SpannerAsyncDriver)
+    assert async_config.connection_type is not None
+
+
+@pytest.mark.asyncio
+async def test_async_provide_pool(async_config: Any) -> None:
+    # Pool creation itself is sync in the current config implementation
+    pool = async_config.provide_pool()
+    assert pool is not None
+    assert isinstance(pool, AbstractSessionPool)
+    assert pool is async_config.pool_instance
+    pool2 = async_config.provide_pool()
+    assert pool is pool2
+
+
+@pytest.mark.asyncio
+async def test_async_provide_connection(async_config: Any) -> None:
+    # provide_connection for Spanner usually yields an AsyncTransaction
+    async with async_config.provide_connection() as connection:
+        assert connection is not None
+        # Check if connection is of expected Spanner async type
+        assert isinstance(connection, AsyncTransaction)
+
+
+@pytest.mark.asyncio
+async def test_async_provide_session(async_config: Any) -> None:
+    async with async_config.provide_session() as driver:
+        assert isinstance(driver, SpannerAsyncDriver)
+        assert driver.connection is not None
+        assert isinstance(driver.connection, AsyncTransaction)
+
+
+@pytest.mark.asyncio
+async def test_async_close_pool(async_config: Any) -> None:
+    # Need a fresh config instance for this test
+    config = AsyncSpannerConfig(
+        project=async_config.project,
+        instance_id=async_config.instance_id,
+        database_id=async_config.database_id,
+    )
+    _pool = config.provide_pool()  # Ensure pool exists
+    assert config.pool_instance is not None
+    # Check internal state if _ping_thread exists and is accessible
+    assert hasattr(config, "_ping_thread")
+    assert config._ping_thread is not None  # noqa: SLF001
+    # Close pool is sync in current implementation
+    config.close_pool()
+    assert config.pool_instance is None
+    assert config._database is None  # noqa: SLF001
+    assert config._client is None  # noqa: SLF001
+    assert hasattr(config, "_ping_thread")
+    assert config._ping_thread is None  # noqa: SLF001
diff --git a/tests/integration/test_adapters/test_spanner/test_driver.py b/tests/integration/test_adapters/test_spanner/test_driver.py
new file mode 100644
index 0000000..c194ee2
--- /dev/null
+++ b/tests/integration/test_adapters/test_spanner/test_driver.py
@@ -0,0 +1,177 @@
+"""Spanner Sync Driver Integration Tests using pytest-databases."""
+
+import os
+from collections.abc import Generator
+from dataclasses import dataclass
+from typing import Any
+
+import pytest
+from google.cloud.spanner_v1 import Client, Transaction
+from pytest_databases.docker.spanner import SpannerService  # type: ignore[import-untyped]
+
+# Import sqlspec types
+from sqlspec.adapters.spanner import (
+    SpannerConfig,
+)
+from sqlspec.exceptions import NotFoundError
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_project(spanner_service: SpannerService) -> str:
+    """Return the project ID used by the Spanner emulator."""
+    return spanner_service.project
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_instance(spanner_service: SpannerService) -> str:
+    """Return the instance ID used by the Spanner emulator."""
+    return spanner_service.instance  # type: ignore[attr-defined]
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_database(spanner_service: SpannerService) -> str:
+    """Return the database ID used by the Spanner emulator."""
+    return spanner_service.database  # type: ignore[attr-defined]
+
+
+@pytest.fixture(scope="session")
+def spanner_emulator_host(spanner_service: SpannerService) -> str:
+    """Return the host used by the Spanner emulator service."""
+    # pytest-databases service might expose host/port if needed for direct client connection
+    # For config, we typically just need project/instance/database if using emulator host env var
+    # If direct connection is needed, service.host/service.port would be used.
+    # We assume the google-cloud-spanner client uses SPANNER_EMULATOR_HOST env var set by pytest-databases.
+    return os.environ.get("SPANNER_EMULATOR_HOST", "localhost:9010")  # Default emulator host
+
+
+@dataclass
+class SimpleModel:
+    id: int
+    name: str
+    value: float
+
+
+@pytest.fixture(scope="session")
+def spanner_sync_config(
+    spanner_emulator_project: str,
+    spanner_emulator_instance: str,
+    spanner_emulator_database: str,
+) -> Any:  # -> SpannerConfig:
+    """Provides a SpannerConfig configured for the pytest-databases emulator."""
+    # The google-cloud-spanner client automatically uses SPANNER_EMULATOR_HOST
+    # environment variable if set, which pytest-databases does.
+    # So, we don't need to explicitly set credentials or host/port for the emulator.
+    return SpannerConfig(
+        project=spanner_emulator_project,
+        instance_id=spanner_emulator_instance,
+        database_id=spanner_emulator_database,
+        # No pool config needed for basic tests, defaults should work
+    )
+
+
+@pytest.fixture
+def spanner_sync_session(
+    spanner_sync_config: Any,
+) -> Generator[Any, None, None]:  # -> Generator[SpannerDriver, None, None]:
+    """Provides a SpannerDriver session within a transaction."""
+    # Use the config's context manager to handle transaction lifecycle
+    with spanner_sync_config.provide_session() as driver:
+        assert isinstance(driver.connection, Transaction)  # Ensure it's a transaction context
+        yield driver
+    # Context manager handles cleanup/commit/rollback
+
+
+# Basic table setup fixture (Sync)
+@pytest.fixture(scope="module", autouse=True)
+def _setup_sync_table(spanner_sync_config: Any) -> None:  # type: ignore[unused-function]
+    """Ensure the test table exists before running sync tests in the module."""
+    # Use a direct client for setup DDL as it might be simpler outside transaction scope
+    # Note: DDL operations might need specific handling in Spanner (e.g., UpdateDatabaseDdl)
+    # This setup assumes direct client interaction works with the emulator.
+    client = Client(project=spanner_sync_config.project)
+    instance = client.instance(spanner_sync_config.instance_id)
+    database = instance.database(spanner_sync_config.database_id)
+
+    # Simple check if table exists (may need adjustment based on emulator behavior)
+    try:
+        with database.snapshot() as snapshot:
+            results = snapshot.execute_sql(
+                "SELECT table_name FROM information_schema.tables WHERE table_name='test_models_sync'"
+            )
+            if list(results):
+                return
+    except Exception:
+        pass
+
+    operation = database.update_ddl([
+        """
+        CREATE TABLE test_models_sync (
+            id INT64 NOT NULL,
+            name STRING(MAX),
+            value FLOAT64
+        ) PRIMARY KEY (id)
+        """
+    ])
+    operation.result(timeout=120)  # Wait for DDL operation to complete
+
+
+def test_sync_spanner_insert_select_one(spanner_sync_session: Any) -> None:  # SpannerDriver
+    """Test inserting and selecting a single row synchronously."""
+    driver = spanner_sync_session
+    # Arrange
+    model_id = 1
+    model_name = "sync_test"
+    model_value = 123.45
+    # Ensure clean state within transaction
+    driver.insert_update_delete("DELETE FROM test_models_sync WHERE id = @id", {"id": model_id})
+
+    # Act: Insert
+    # insert_update_delete returns the number of rows affected
+    _ = driver.insert_update_delete(
+        "INSERT INTO test_models_sync (id, name, value) VALUES (@id, @name, @value)",
+        parameters={"id": model_id, "name": model_name, "value": model_value},
+    )
+
+    # Act: Select
+    result = driver.select_one_or_none(
+        "SELECT id, name, value FROM test_models_sync WHERE id = @id",
+        parameters={"id": model_id},
+        schema_type=SimpleModel,
+    )
+
+    # Assert
+    assert result is not None
+    assert isinstance(result, SimpleModel)
+    assert result.id == model_id
+    assert result.name == model_name
+    assert result.value == model_value
+
+
+def test_sync_spanner_select_one_or_none_not_found(spanner_sync_session: Any) -> None:  # SpannerDriver
+    """Test selecting a non-existent row synchronously."""
+    driver = spanner_sync_session
+    # Arrange: Ensure ID does not exist
+    non_existent_id = 999
+    driver.insert_update_delete("DELETE FROM test_models_sync WHERE id = @id", {"id": non_existent_id})
+
+    # Act
+    result = driver.select_one_or_none(
+        "SELECT * FROM test_models_sync WHERE id = @id", parameters={"id": non_existent_id}, schema_type=SimpleModel
+    )
+
+    # Assert
+    assert result is None
+
+
+def test_sync_spanner_select_one_raises_not_found(spanner_sync_session: Any) -> None:  # SpannerDriver
+    """Test select_one raises NotFoundError for a non-existent row synchronously."""
+    driver = spanner_sync_session
+    # Arrange: Ensure ID does not exist
+    non_existent_id = 998
+    driver.insert_update_delete("DELETE FROM test_models_sync WHERE id = @id", {"id": non_existent_id})
+
+    # Act & Assert
+    with pytest.raises(NotFoundError):
+        driver.select_one(
+            "SELECT * FROM test_models_sync WHERE id = @id", parameters={"id": non_existent_id}, schema_type=SimpleModel
+        )

From 43a67d29409a9d17643edf44fa7f11e0c9817234 Mon Sep 17 00:00:00 2001
From: Cody Fincher
Date: Thu, 24 Apr 2025 03:10:40 +0000
Subject: [PATCH 2/2] fix: type

---
 .../integration/test_adapters/test_spanner/test_connection.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/integration/test_adapters/test_spanner/test_connection.py b/tests/integration/test_adapters/test_spanner/test_connection.py
index 8bde452..9801ae5 100644
--- a/tests/integration/test_adapters/test_spanner/test_connection.py
+++ b/tests/integration/test_adapters/test_spanner/test_connection.py
@@ -7,7 +7,7 @@
 from pytest_databases.docker.spanner import SpannerService
 
 # Import sqlspec types
-from sqlspec.adapters.spanner import SpannerConfig, SpannerPoolConfig
+from sqlspec.adapters.spanner import SpannerConfig, SpannerDriver, SpannerPoolConfig
 
 
 @pytest.fixture(scope="session")
@@ -47,7 +47,7 @@ def sync_config(
 def test_sync_config_properties(sync_config: Any) -> None:
     assert sync_config.is_async is False
     assert sync_config.support_connection_pooling is True  # Spanner uses pools
-    assert issubclass(sync_config.driver_type, SpannerSyncDriver)
+    assert issubclass(sync_config.driver_type, SpannerDriver)
     # Check connection_type can be resolved (might need adjustment based on actual Union)
     assert sync_config.connection_type is not None