Skip to content

Commit 706131b

Browse files
committed
Added transaction{}, option to autocommit
1 parent 4fd498c commit 706131b

11 files changed

+106
-14
lines changed

preql/__main__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
help="database url (postgres://user:password@host:port/db_name")
2222
parser.add_argument('--python-traceback', action='store_true',
2323
help="Show the Python traceback when an exception causes the interpreter to quit")
24+
parser.add_argument('--autocommit', action='store_true')
2425

2526

2627
def find_dot_preql():
@@ -59,7 +60,7 @@ def main():
5960

6061

6162

62-
kw = {'print_sql': args.print_sql}
63+
kw = {'print_sql': args.print_sql, 'autocommit': args.autocommit}
6364
if args.database:
6465
kw['db_uri'] = args.database
6566
kw['auto_create'] = True

preql/api.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ class Preql:
126126

127127
__name__ = "Preql"
128128

129-
def __init__(self, db_uri: str='sqlite://:memory:', print_sql: bool=settings.print_sql, auto_create: bool = False):
129+
def __init__(self, db_uri: str='sqlite://:memory:', print_sql: bool=settings.print_sql,
130+
auto_create: bool = False, autocommit: bool = False
131+
):
130132
"""Initialize a new Preql instance
131133
132134
Parameters:
@@ -137,6 +139,7 @@ def __init__(self, db_uri: str='sqlite://:memory:', print_sql: bool=settings.pri
137139
self._print_sql = print_sql
138140
self._auto_create = auto_create
139141
self._display = display.RichDisplay()
142+
self._autocommit = autocommit
140143
# self.engine.ping()
141144

142145
engine = create_engine(self._db_uri, print_sql=self._print_sql, auto_create=auto_create)
@@ -160,7 +163,7 @@ def set_output_format(self, fmt):
160163
def _reset_interpreter(self, engine=None):
161164
if engine is None:
162165
engine = self._interp.state.db
163-
self._interp = Interpreter(engine, self._display, _preql_inst=self)
166+
self._interp = Interpreter(engine, self._display, _preql_inst=self, autocommit=self._autocommit)
164167

165168
def close(self):
166169
self._interp.state.db.close()

preql/core/autocomplete.py

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def eval_autocomplete(t: ast.Try, go_inside):
3131
with use_scope(scope):
3232
eval_autocomplete(t.catch_block, go_inside)
3333

34+
@dsp
35+
def eval_autocomplete(t: ast.Transaction, go_inside):
36+
eval_autocomplete(t.do, go_inside)
37+
3438
@dsp
3539
def eval_autocomplete(a: ast.InsertRows, go_inside):
3640
eval_autocomplete(a.value, go_inside)

preql/core/evaluate.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,17 @@ def resolve(type_: ast.Type):
102102

103103

104104

105-
def db_query(sql_code, subqueries=None):
105+
def db_query(sql_code, subqueries=None, *, modifies=True):
106106
try:
107-
return get_db().query(sql_code, subqueries)
107+
res = get_db().query(sql_code, subqueries)
108108
except exc.DatabaseQueryError as e:
109109
raise Signal.make(T.DbQueryError, None, e.args[0]) from e
110110

111+
if modifies and context.state.autocommit:
112+
get_db().commit()
113+
114+
return res
115+
111116
def drop_table(table_type):
112117
name = table_type.options['name']
113118
code = sql.compile_drop_table(name)
@@ -346,6 +351,19 @@ def _execute(f: ast.For):
346351
with use_scope({f.var: objects.from_python(i)}):
347352
execute(f.do)
348353

354+
@method
355+
def _execute(t: ast.Transaction):
356+
db_query(sql.BeginTransaction())
357+
with context(state=context.state.set_autocommit(False)):
358+
try:
359+
res = execute(t.do)
360+
except Signal as e:
361+
get_db().rollback()
362+
raise
363+
364+
get_db().commit()
365+
return res
366+
349367
@method
350368
def _execute(t: ast.Try):
351369
try:
@@ -410,7 +428,8 @@ def _execute(t: ast.Throw):
410428
e = evaluate(t.value)
411429
if isinstance(e, ast.Ast):
412430
raise exc.InsufficientAccessLevel()
413-
assert isinstance(e, Exception), e
431+
if not isinstance(e, Exception):
432+
raise Signal.make(T.TypeError, t, f"Can only throw an exception, not {e.type}")
414433
raise e
415434

416435
def execute(stmt):
@@ -870,7 +889,7 @@ def _new_row(new_ast, table, matched):
870889
raise Signal.make(T.TypeError, new_ast, f"'new' expects a persistent table. Instead got a table expression.")
871890

872891
if get_db().target == sql.bigquery:
873-
rowid = db_query(sql.FuncCall(T.string, 'GENERATE_UUID', []))
892+
rowid = db_query(sql.FuncCall(T.string, 'GENERATE_UUID', []), modifies=False)
874893
keys += ['id']
875894
values += [sql.make_value(rowid)]
876895
elif get_db().target == sql.snowflake:
@@ -885,7 +904,7 @@ def _new_row(new_ast, table, matched):
885904
db_query(q)
886905

887906
if get_db().target not in (sql.bigquery, sql.snowflake):
888-
rowid = db_query(sql.LastRowId())
907+
rowid = db_query(sql.LastRowId(), modifies=False)
889908

890909
d = SafeDict({'id': objects.pyvalue_inst(rowid)})
891910
d.update({p.name:v for p, v in matched})
@@ -1015,7 +1034,7 @@ def localize(inst: objects.Instance):
10151034
if inst.code is sql.null:
10161035
return None
10171036

1018-
return db_query(inst.code, inst.subqueries)
1037+
return db_query(inst.code, inst.subqueries, modifies=False)
10191038

10201039
@method
10211040
def localize(inst: objects.ValueInstance):

preql/core/interpreter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def __getattr__(self, attr):
4949

5050

5151
class Interpreter:
52-
def __init__(self, sqlengine, display, use_core=True, _preql_inst=None):
52+
def __init__(self, sqlengine, display, use_core=True, _preql_inst=None, autocommit=False):
5353
assert _preql_inst
5454
self._preql_inst = _preql_inst # XXX temporary hack
5555

56-
self.state = ThreadState.from_components(self, sqlengine, display, initial_namespace())
56+
self.state = ThreadState.from_components(self, sqlengine, display, initial_namespace(), autocommit=autocommit)
5757
if use_core:
5858
mns = import_module(self.state, ast.Import('__builtins__', use_core=False)).namespace
5959
bns = self.state.get_var('__builtins__').namespace

preql/core/parser.py

+1
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def name_path(self, path, name):
278278
while_stmt = ast.While
279279
for_stmt = ast.For
280280
try_catch = ast.Try
281+
transaction = ast.Transaction
281282
one = lambda self, nullable, expr: ast.One(expr, nullable is not None)
282283

283284
def marker(self, _marker):

preql/core/pql_ast.py

+4
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ class Try(Statement):
290290
catch_expr: Expr
291291
catch_block: CodeBlock
292292

293+
@dataclass
294+
class Transaction(Statement):
295+
do: CodeBlock
296+
293297
@dataclass
294298
class If(Statement):
295299
cond: Object

preql/core/preql.lark

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ module: _NL? stmt+ -> as_list
1313
| import_stmt
1414
| throw
1515
| try_catch
16+
| transaction
1617
| expr _NL
1718

1819
pql_dict: "{" _ml_sep? ml_list{named_expr} _ml_sep? "}"
@@ -52,6 +53,7 @@ func_def: "func" name func_params "=" expr [_NL string_raw] _NL -> func_def_shor
5253
| "func" name func_params codeblock _NL
5354

5455
try_catch: "try" codeblock _NL? "catch" "(" [name ":"] expr ")" codeblock _NL
56+
transaction: "transaction" codeblock _NL
5557
if_stmt: "if" "(" expr ")" codeblock (_NL? "else" (codeblock _NL|if_stmt) | _NL)
5658
for_stmt: "for" "(" name "in" expr ")" codeblock _NL
5759
while_stmt: "while" "(" expr ")" codeblock _NL

preql/core/sql.py

+7
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,13 @@ def _compile(self, qb):
490490
class SqlStatement(SqlTree):
491491
type = T.nulltype
492492

493+
@dataclass
494+
class BeginTransaction(SqlStatement):
495+
def _compile(self, qb):
496+
if qb.target == mysql:
497+
return ["START TRANSACTION"]
498+
return ["BEGIN TRANSACTION"]
499+
493500
@dataclass
494501
class AddIndex(SqlStatement):
495502
index_name: Id

preql/core/state.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,19 @@ def unique_name(self, obj):
127127

128128

129129
class ThreadState:
130-
def __init__(self, state, ns=None):
130+
def __init__(self, state, ns=None, *, autocommit=False):
131131
self.state = state
132132
# Add logger?
133133

134134
self.ns = Namespace(ns)
135135
self.access_level = AccessLevels.WRITE_DB
136136
self.stacktrace = []
137+
self.autocommit = autocommit
137138

138139
@classmethod
139-
def from_components(cls, interp, db, display, ns=None):
140+
def from_components(cls, interp, db, display, ns=None, *, autocommit):
140141
state = State(interp, db, display)
141-
return cls(state, ns)
142+
return cls(state, ns, autocommit=autocommit)
142143

143144
@property
144145
def interp(self):
@@ -159,6 +160,7 @@ def clone(cls, inst):
159160
s.ns = copy(inst.ns)
160161
s.access_level = inst.access_level
161162
s.stacktrace = copy(inst.stacktrace)
163+
s.autocommit = inst.autocommit
162164
return s
163165

164166

@@ -171,9 +173,15 @@ def reduce_access(self, new_level):
171173
s.access_level = new_level
172174
return s
173175

176+
def set_autocommit(self, autocommit):
177+
s = copy(self)
178+
s.autocommit = autocommit
179+
return s
180+
174181
def require_access(self, level):
175182
if self.access_level < level:
176183
raise InsufficientAccessLevel(level)
184+
177185
def catch_access(self, level):
178186
if self.access_level < level:
179187
raise Exception("Bad access. Security risk.")

tests/test_basic.py

+43
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,50 @@ def test_threading(self):
15641564
if p._interp.state.db.target != mysql: # Not supported
15651565
assert p('a{item} - [..100]') == []
15661566

1567+
@uses_tables("a")
1568+
def test_transaction1(self):
1569+
p = self.Preql()
1570+
p('''
1571+
table a {
1572+
x: int
1573+
}
1574+
1575+
try{
1576+
transaction {
1577+
new a(4)
1578+
throw new Exception("Some Error")
1579+
}
1580+
} catch(Exception) {
1581+
}
1582+
1583+
table a {
1584+
x: int
1585+
}
1586+
''')
1587+
1588+
assert not p.a
1589+
1590+
def test_transaction2(self):
1591+
p = self.Preql()
1592+
p.rollback()
1593+
p('''
1594+
table a {
1595+
x: int
1596+
}
1597+
1598+
transaction {
1599+
try{
1600+
new a(5)
1601+
throw new Exception("A")
1602+
} catch(Exception) {
1603+
}
1604+
}
1605+
1606+
''')
15671607

1608+
self.assertEqual( p(r'list(a{x})'), [5] )
1609+
p.run_statement('DROP TABLE a')
1610+
p.commit()
15681611

15691612
class TestFlow(PreqlTests):
15701613
def test_new_freezes_values(self):

0 commit comments

Comments
 (0)