Skip to content

Commit 976319a

Browse files
committed
More refactor
1 parent dabba1b commit 976319a

10 files changed

+358
-337
lines changed

preql/api.py

+17-50
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,32 @@
11
from contextlib import contextmanager
22

3-
43
from . import settings
54
from . import pql_ast as ast
65
from . import pql_objects as objects
76
from .utils import classify
87
from .interpreter import Interpreter
9-
from .evaluate import cast_to_python
10-
from .interp_common import create_engine, call_pql_func
8+
from .sql_interface import create_engine
119
from .pql_types import T
12-
from .pql_functions import import_pandas
13-
from .context import context
14-
from . import sql
1510

1611
from . import display
1712
display.install_reprs()
1813

1914

20-
def _call_pql_func(state, name, args):
21-
with context(state=state):
22-
count = call_pql_func(state, name, args)
23-
return cast_to_python(state, count)
24-
25-
2615
class TablePromise:
2716
"""Returned by Preql whenever the result is a table
2817
2918
Fetching values creates queries to database engine
3019
"""
3120

32-
def __init__(self, state, inst):
33-
self._state = state
21+
def __init__(self, interp, inst):
22+
self._interp = interp
3423
self._inst = inst
3524
self._rows = None
3625

3726
def to_json(self):
3827
"Returns table as a list of rows, i.e. ``[{col1: value, col2: value, ...}, ...]``"
3928
if self._rows is None:
40-
self._rows = cast_to_python(self._state, self._inst)
29+
self._rows = self._interp.cast_to_python(self._inst)
4130
assert self._rows is not None
4231
return self._rows
4332

@@ -55,30 +44,30 @@ def __eq__(self, other):
5544

5645
def __len__(self):
5746
"Run a count query on table"
58-
return _call_pql_func(self._state, 'count', [self._inst])
47+
count = self._interp.call_builtin_func('count', [self._inst])
48+
return self._interp.cast_to_python(count)
5949

6050
def __iter__(self):
6151
return iter(self.to_json())
6252

6353
def __getitem__(self, index):
6454
"Run a slice query on table"
65-
with context(state=self._state):
66-
if isinstance(index, slice):
67-
offset = index.start or 0
68-
limit = index.stop - offset
69-
return call_pql_func(self._state, 'limit_offset', [self._inst, ast.make_const(limit), ast.make_const(offset)])
55+
if isinstance(index, slice):
56+
offset = index.start or 0
57+
limit = index.stop - offset
58+
return self._interp.call_builtin_func('limit_offset', [self._inst, ast.make_const(limit), ast.make_const(offset)])
7059

71-
# TODO different debug log level / mode
72-
res ,= cast_to_python(self._state, self[index:index+1])
73-
return res
60+
# TODO different debug log level / mode
61+
res ,= self._interp.cast_to_python(self[index:index+1])
62+
return res
7463

7564
def __repr__(self):
7665
return repr(self.to_json())
7766

7867

7968
def _prepare_instance_for_user(interp, inst):
8069
if inst.type <= T.table:
81-
return TablePromise(interp.state, inst)
70+
return TablePromise(interp, inst)
8271

8372
return interp.localize_obj(inst)
8473

@@ -104,6 +93,7 @@ def __init__(self, db_uri: str='sqlite://:memory:', print_sql: bool=settings.pri
10493
"""
10594
self._db_uri = db_uri
10695
self._print_sql = print_sql
96+
self._auto_create = auto_create
10797
# self.engine.ping()
10898

10999
engine = create_engine(self._db_uri, print_sql=self._print_sql, auto_create=auto_create)
@@ -182,41 +172,18 @@ def commit(self):
182172
def rollback(self):
183173
return self.interp.state.db.rollback()
184174

185-
def _drop_tables(self, *tables):
186-
state = self.interp.state
187-
# XXX temporary. Used for testing
188-
for t in tables:
189-
t = sql._quote(state.db.target, state.db.qualified_name(t))
190-
state.db._execute_sql(T.nulltype, f"DROP TABLE {t};", state)
191175

192176
def import_pandas(self, **dfs):
193177
"""Import pandas.DataFrame instances into SQL tables
194178
195179
Example:
196180
>>> pql.import_pandas(a=df_a, b=df_b)
197181
"""
198-
with self.interp.setup_context():
199-
return list(import_pandas(self.interp.state, dfs))
182+
return self.interp.import_pandas(dfs)
200183

201184

202185
def load_all_tables(self):
203-
table_types = self.interp.state.db.import_table_types(self.interp.state)
204-
table_types_by_schema = classify(table_types, lambda x: x[0], lambda x: x[1:])
205-
206-
for schema_name, table_types in table_types_by_schema.items():
207-
if schema_name:
208-
schema = objects.Module(schema_name, {})
209-
self.interp.set_var(schema_name, schema)
210-
211-
for table_name, table_type in table_types:
212-
db_name = table_type.options['name']
213-
inst = objects.new_table(table_type, db_name)
214-
215-
if schema_name:
216-
schema.namespace[table_name] = inst
217-
else:
218-
if not self.interp.has_var(table_name):
219-
self.interp.set_var(table_name, inst)
186+
return self.interp.load_all_tables()
220187

221188

222189

preql/compiler.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import pql_objects as objects
99
from . import pql_ast as ast
1010
from . import sql
11-
from .interp_common import dy, State, assert_type, new_value_instance, evaluate, simplify, call_pql_func, cast_to_python_string, cast_to_python_int
11+
from .interp_common import dy, State, assert_type, new_value_instance, evaluate, simplify, call_builtin_func, cast_to_python_string, cast_to_python_int
1212
from .pql_types import T, Type, Id, ITEM_NAME, dp_inst
1313
from .types_impl import flatten_type, pql_repr, kernel_type
1414
from .casts import cast
@@ -321,7 +321,7 @@ def _contains(state, op, a: T.string, b: T.string):
321321
'in': 'str_contains',
322322
'!in': 'str_notcontains',
323323
}[op]
324-
return call_pql_func(state, f, [a, b])
324+
return call_builtin_func(state, f, [a, b])
325325

326326
@dp_inst
327327
def _contains(state, op, a: T.primitive, b: T.table):
@@ -412,7 +412,7 @@ def _compare(_state, op, a: T.primitive, b: T.primitive):
412412
@dp_inst
413413
def _compare(state, op, a: T.type, b: T.type):
414414
if op == '<=':
415-
return call_pql_func(state, "issubclass", [a, b])
415+
return call_builtin_func(state, "issubclass", [a, b])
416416
if op != '=':
417417
raise exc.Signal.make(T.NotImplementedError, op, f"Cannot compare types using: {op}")
418418
return new_value_instance(a == b)
@@ -491,7 +491,7 @@ def _compile_arith(state, arith, a: T.table, b: T.table):
491491
def _compile_arith(state, arith, a: T.string, b: T.int):
492492
if arith.op != '*':
493493
raise Signal.make(T.TypeError, arith.op, f"Operator '{arith.op}' not supported between string and integer.")
494-
return call_pql_func(state, "repeat", [a, b])
494+
return call_builtin_func(state, "repeat", [a, b])
495495

496496

497497
@dp_inst

preql/display.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from . import pql_objects as objects
1010
from . import pql_ast as ast
1111
from .types_impl import dp_type, pql_repr
12-
from .interp_common import call_pql_func, cast_to_python_int, cast_to_python
12+
from .interp_common import call_builtin_func, cast_to_python_int, cast_to_python
1313

1414
from .context import context
1515

@@ -67,11 +67,8 @@ def pql_repr(t: T.nulltype, value):
6767

6868

6969
def table_limit(table, state, limit, offset=0):
70-
return call_pql_func(state, 'limit_offset', [table, ast.make_const(limit), ast.make_const(offset)])
70+
return call_builtin_func(state, 'limit_offset', [table, ast.make_const(limit), ast.make_const(offset)])
7171

72-
def _call_pql_func(state, name, args):
73-
count = call_pql_func(state, name, args)
74-
return cast_to_python_int(state, count)
7572

7673
def _html_table(name, count_str, rows, offset, has_more, colors):
7774
assert colors
@@ -164,7 +161,7 @@ def _view_table(state, table, size, offset):
164161
def table_repr(self, offset=0):
165162
state = context.state
166163

167-
count = _call_pql_func(state, 'count', [table_limit(self, state, MAX_AUTO_COUNT)])
164+
count = cast_to_python_int(state, call_builtin_func(state, 'count', [table_limit(self, state, MAX_AUTO_COUNT)]))
168165
if count == MAX_AUTO_COUNT:
169166
count_str = f'>={count}'
170167
else:

preql/evaluate.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
from .utils import safezip, dataclass, SafeDict, listgen
6-
from .interp_common import assert_type, exclude_fields, call_pql_func, is_global_scope
6+
from .interp_common import assert_type, exclude_fields, call_builtin_func, is_global_scope
77
from .exceptions import InsufficientAccessLevel, ReturnSignal, Signal
88
from . import exceptions as exc
99
from . import pql_objects as objects
@@ -85,14 +85,14 @@ def _execute(state: State, struct_def: ast.StructDef):
8585

8686
def db_query(state: State, sql_code, subqueries=None):
8787
try:
88-
return state.db.query(sql_code, subqueries, state=state)
88+
return state.db.query(sql_code, subqueries)
8989
except exc.DatabaseQueryError as e:
9090
raise Signal.make(T.DbQueryError, None, e.args[0]) from e
9191

9292
def drop_table(state, table_type):
9393
name ,= table_type.options['name'].parts
9494
code = sql.compile_drop_table(state, name)
95-
return state.db.query(code, {}, state=state)
95+
return state.db.query(code, {})
9696

9797

9898
@dy
@@ -125,7 +125,7 @@ def _execute(state: State, table_def: ast.TableDef):
125125
exists = state.db.table_exists(db_name.repr_name)
126126
if exists:
127127
assert not t.options['temporary']
128-
cur_type = state.db.import_table_type(state, db_name.repr_name, None if ellipsis else set(t.elems) | {'id'})
128+
cur_type = state.db.import_table_type(db_name.repr_name, None if ellipsis else set(t.elems) | {'id'})
129129

130130
if ellipsis:
131131
elems_to_add = {Str(n, ellipsis.text_ref): v for n, v in cur_type.elems.items() if n not in t.elems}
@@ -555,7 +555,7 @@ def _call_expr(state, expr):
555555
# TODO fix these once we have proper types
556556
@dy
557557
def test_nonzero(state: State, table: objects.TableInstance):
558-
count = call_pql_func(state, "count", [table])
558+
count = call_builtin_func(state, "count", [table])
559559
return bool(cast_to_python_int(state, count))
560560

561561
@dy

preql/interp_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def exclude_fields(state, table, fields):
199199
proj = ast.Projection(table, [ast.NamedField(None, ast.Ellipsis(None, exclude=list(fields) ), user_defined=False)])
200200
return evaluate(state, proj)
201201

202-
def call_pql_func(state, name, args):
202+
def call_builtin_func(state, name, args):
203203
"Call a builtin pql function"
204204
builtins = state.ns.get_var('__builtins__')
205205
assert isinstance(builtins, objects.Module)

preql/interpreter.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from pathlib import Path
22
from functools import wraps
33

4+
from .utils import classify
45
from .exceptions import Signal, pql_SyntaxError, ReturnSignal
5-
from .evaluate import State, execute, eval_func_call, import_module, evaluate, localize
6+
from .evaluate import State, execute, eval_func_call, import_module, evaluate, localize, cast_to_python
67
from .parser import parse_stmts
78
from . import pql_ast as ast
89
from . import pql_objects as objects
9-
from .interp_common import new_value_instance
10+
from .interp_common import new_value_instance, call_builtin_func
1011
from .context import context
12+
from .pql_functions import import_pandas
1113

1214
from .pql_functions import internal_funcs, joins
1315
from .pql_types import T, Object
@@ -54,11 +56,11 @@ def _execute_code(self, code, source_file, args=None):
5456
raise Signal(T.SyntaxError, [e.text_ref], e.message)
5557

5658

57-
last = None
5859
if stmts:
5960
if isinstance(stmts[0], ast.Const) and stmts[0].type == T.string:
6061
self.set_var('__doc__', stmts[0].value)
6162

63+
last = None
6264
for stmt in stmts:
6365
try:
6466
last = execute(self.state, stmt)
@@ -109,3 +111,40 @@ def localize_obj(self, obj):
109111
def call_func(self, fname, args):
110112
res = eval_func_call(self.state, self.state.get_var(fname), args)
111113
return evaluate(self.state, res)
114+
115+
@entrypoint
116+
def cast_to_python(self, obj):
117+
return cast_to_python(self.state, obj)
118+
119+
@entrypoint
120+
def call_builtin_func(self, name, args):
121+
return call_builtin_func(self.state, name, args)
122+
123+
@entrypoint
124+
def import_pandas(self, dfs):
125+
return list(import_pandas(self.state, dfs))
126+
127+
@entrypoint
128+
def list_tables(self):
129+
return self.state.db.list_tables()
130+
131+
132+
@entrypoint
133+
def load_all_tables(self):
134+
table_types = self.state.db.import_table_types()
135+
table_types_by_schema = classify(table_types, lambda x: x[0], lambda x: x[1:])
136+
137+
for schema_name, table_types in table_types_by_schema.items():
138+
if schema_name:
139+
schema = objects.Module(schema_name, {})
140+
self.set_var(schema_name, schema)
141+
142+
for table_name, table_type in table_types:
143+
db_name = table_type.options['name']
144+
inst = objects.new_table(table_type, db_name)
145+
146+
if schema_name:
147+
schema.namespace[table_name] = inst
148+
else:
149+
if not self.has_var(table_name):
150+
self.set_var(table_name, inst)

preql/pql_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def pql_import_table(state: State, name: T.string, columns: T.list[T.string].as_
632632
columns_whitelist = set(columns_whitelist)
633633

634634
# Get table type
635-
t = state.db.import_table_type(state, name_str, columns_whitelist)
635+
t = state.db.import_table_type(name_str, columns_whitelist)
636636
assert t <= T.table
637637

638638
# Get table contents
@@ -721,7 +721,7 @@ def pql_tables(state: State):
721721
The resulting table has two columns: name, and type.
722722
"""
723723
names = state.db.list_tables()
724-
values = [(name, state.db.import_table_type(state, name, None)) for name in names]
724+
values = [(name, state.db.import_table_type(name, None)) for name in names]
725725
tuples = [sql.Tuple(T.list[T.string], [new_str(n).code,new_str(t).code]) for n,t in values]
726726

727727
table_type = T.table(dict(name=T.string, type=T.string))

0 commit comments

Comments
 (0)