Skip to content

Commit

Permalink
Merge pull request #410 from machow/fix-mock-engine
Browse files Browse the repository at this point in the history
fix: mock_sqlalchemy_engine across dialects, collecting
  • Loading branch information
machow authored Mar 29, 2022
2 parents 29595a6 + bc6fdb6 commit 949642e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
15 changes: 13 additions & 2 deletions siuba/sql/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import importlib

try:
# once we drop sqlalchemy 1.2, can use create_mock_engine function
from sqlalchemy.engine.mock import MockConnection
except ImportError:
# monkey patch old sqlalchemy mock, so it can be a context handler
from sqlalchemy.engine.strategies import MockEngineStrategy
MockConnection = MockEngineStrategy.MockConnection


def get_dialect_translator(name):
mod = importlib.import_module('siuba.sql.dialects.{}'.format(name))
return mod.translator
Expand Down Expand Up @@ -37,11 +46,13 @@ def mock_sqlalchemy_engine(dialect):
show_query(query)
"""

from sqlalchemy.engine import Engine
from sqlalchemy.dialects import registry

dialect_cls = registry.load('postgresql')
return Engine(None, dialect_cls(), '')
dialect_cls = registry.load(dialect)

return MockConnection(dialect_cls(), lambda *args, **kwargs: None)


# Temporary fix for pandas bug (https://github.com/pandas-dev/pandas/issues/35484)
Expand Down
16 changes: 15 additions & 1 deletion siuba/sql/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@
)

from .translate import CustomOverClause, SqlColumn, SqlColumnAgg
from .utils import get_dialect_translator, _FixedSqlDatabase, _sql_select, _sql_column_collection, _sql_add_columns, _sql_with_only_columns
from .utils import (
get_dialect_translator,
_FixedSqlDatabase,
_sql_select,
_sql_column_collection,
_sql_add_columns,
_sql_with_only_columns,
MockConnection
)

from sqlalchemy import sql
import sqlalchemy
Expand Down Expand Up @@ -467,6 +475,12 @@ def _collect(__data, as_df = True):
# compile_kwargs = {"literal_binds": True}
#)

if isinstance(__data.source, MockConnection):
# a mock sqlalchemy is being used to show_query, and echo queries.
# it doesn't return a result object or have a context handler, so
# we need to bail out early
return

with __data.source.connect() as conn:
if as_df:
sql_db = _FixedSqlDatabase(conn)
Expand Down
16 changes: 15 additions & 1 deletion siuba/tests/test_sql_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from siuba.sql.utils import get_dialect_translator
from siuba.sql.utils import get_dialect_translator, mock_sqlalchemy_engine
from siuba.sql.verbs import collect
from siuba.sql import LazyTbl
import pytest

@pytest.mark.parametrize('name', [
Expand All @@ -8,3 +10,15 @@
])
def test_get_dialect_translator(name):
get_dialect_translator(name)

def test_mock_sqlalchemy_engine_dialect():
engine = mock_sqlalchemy_engine("postgresql")
assert engine.dialect.name == "postgresql"

engine = mock_sqlalchemy_engine("sqlite")
assert engine.dialect.name == "sqlite"

def test_mock_sqlalchemy_engine_no_collect():
engine = mock_sqlalchemy_engine("sqlite")
tbl = LazyTbl(engine, "some_table", ["x"])
assert collect(tbl) is None

0 comments on commit 949642e

Please sign in to comment.