@@ -102,12 +102,17 @@ def resolve(type_: ast.Type):
102
102
103
103
104
104
105
- def db_query (sql_code , subqueries = None ):
105
+ def db_query (sql_code , subqueries = None , * , modifies = True ):
106
106
try :
107
- return get_db ().query (sql_code , subqueries )
107
+ res = get_db ().query (sql_code , subqueries )
108
108
except exc .DatabaseQueryError as e :
109
109
raise Signal .make (T .DbQueryError , None , e .args [0 ]) from e
110
110
111
+ if modifies and context .state .autocommit :
112
+ get_db ().commit ()
113
+
114
+ return res
115
+
111
116
def drop_table (table_type ):
112
117
name = table_type .options ['name' ]
113
118
code = sql .compile_drop_table (name )
@@ -346,6 +351,19 @@ def _execute(f: ast.For):
346
351
with use_scope ({f .var : objects .from_python (i )}):
347
352
execute (f .do )
348
353
354
+ @method
355
+ def _execute (t : ast .Transaction ):
356
+ db_query (sql .BeginTransaction ())
357
+ with context (state = context .state .set_autocommit (False )):
358
+ try :
359
+ res = execute (t .do )
360
+ except Signal as e :
361
+ get_db ().rollback ()
362
+ raise
363
+
364
+ get_db ().commit ()
365
+ return res
366
+
349
367
@method
350
368
def _execute (t : ast .Try ):
351
369
try :
@@ -410,7 +428,8 @@ def _execute(t: ast.Throw):
410
428
e = evaluate (t .value )
411
429
if isinstance (e , ast .Ast ):
412
430
raise exc .InsufficientAccessLevel ()
413
- assert isinstance (e , Exception ), e
431
+ if not isinstance (e , Exception ):
432
+ raise Signal .make (T .TypeError , t , f"Can only throw an exception, not { e .type } " )
414
433
raise e
415
434
416
435
def execute (stmt ):
@@ -870,7 +889,7 @@ def _new_row(new_ast, table, matched):
870
889
raise Signal .make (T .TypeError , new_ast , f"'new' expects a persistent table. Instead got a table expression." )
871
890
872
891
if get_db ().target == sql .bigquery :
873
- rowid = db_query (sql .FuncCall (T .string , 'GENERATE_UUID' , []))
892
+ rowid = db_query (sql .FuncCall (T .string , 'GENERATE_UUID' , []), modifies = False )
874
893
keys += ['id' ]
875
894
values += [sql .make_value (rowid )]
876
895
elif get_db ().target == sql .snowflake :
@@ -885,7 +904,7 @@ def _new_row(new_ast, table, matched):
885
904
db_query (q )
886
905
887
906
if get_db ().target not in (sql .bigquery , sql .snowflake ):
888
- rowid = db_query (sql .LastRowId ())
907
+ rowid = db_query (sql .LastRowId (), modifies = False )
889
908
890
909
d = SafeDict ({'id' : objects .pyvalue_inst (rowid )})
891
910
d .update ({p .name :v for p , v in matched })
@@ -1015,7 +1034,7 @@ def localize(inst: objects.Instance):
1015
1034
if inst .code is sql .null :
1016
1035
return None
1017
1036
1018
- return db_query (inst .code , inst .subqueries )
1037
+ return db_query (inst .code , inst .subqueries , modifies = False )
1019
1038
1020
1039
@method
1021
1040
def localize (inst : objects .ValueInstance ):
0 commit comments