Skip to content

Commit e3685e7

Browse files
committed
Replace parameterized tests with per-connection type tests with transactions.
1 parent 3b01536 commit e3685e7

File tree

1 file changed

+67
-26
lines changed

1 file changed

+67
-26
lines changed

tests/test_export.py

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88

99
Base = declarative_base()
1010
engine = sa.create_engine('postgresql:///copy-test')
11-
connection_types = {
12-
'engine': lambda session: session.connection().engine,
13-
'connection': lambda session: session.connection(),
14-
'raw_connection': lambda session: session.connection().connection,
15-
}
1611

1712
class Album(Base):
1813
__tablename__ = 'album'
@@ -44,79 +39,76 @@ def objects(session):
4439
finally:
4540
engine.execute(Album.__table__.delete())
4641

47-
@pytest.mark.parametrize("conn_type", connection_types.values())
4842
class TestCopyTo:
4943

50-
def test_copy_query(self, session, objects, conn_type):
44+
def test_copy_query(self, session, objects):
5145
sio = io.StringIO()
52-
copy_to(session.query(Album), sio, conn_type(session))
46+
copy_to(session.query(Album), sio, session.connection().engine)
5347
lines = sio.getvalue().strip().split('\n')
5448
assert len(lines) == 3
5549
assert lines[0].split('\t') == [str(objects[0].id), objects[0].name]
5650

57-
def test_copy_table(self, session, objects, conn_type):
51+
def test_copy_table(self, session, objects):
5852
sio = io.StringIO()
59-
copy_to(Album.__table__.select(), sio, conn_type(session))
53+
copy_to(Album.__table__.select(), sio, session.connection().engine)
6054
lines = sio.getvalue().strip().split('\n')
6155
assert len(lines) == 3
6256
assert lines[0].split('\t') == [str(objects[0].id), objects[0].name]
6357

64-
def test_copy_csv(self, session, objects, conn_type):
58+
def test_copy_csv(self, session, objects):
6559
sio = io.StringIO()
6660
flags = {'format': 'csv', 'header': True}
67-
copy_to(session.query(Album), sio, conn_type(session), **flags)
61+
copy_to(session.query(Album), sio, session.connection().engine, **flags)
6862
lines = sio.getvalue().strip().split('\n')
6963
assert len(lines) == 4
7064
assert lines[0].split(',') == ['aid', 'name']
7165
assert lines[1].split(',') == [str(objects[0].id), objects[0].name]
7266

73-
@pytest.mark.parametrize("conn_type", connection_types.values())
7467
class TestCopyRename:
7568

76-
def test_rename_model(self, session, objects, conn_type):
69+
def test_rename_model(self, session, objects):
7770
sio = io.StringIO()
7871
flags = {'format': 'csv', 'header': True}
7972
query = relabel_query(session.query(Album))
80-
copy_to(query, sio, conn_type(session), **flags)
73+
copy_to(query, sio, session.connection().engine, **flags)
8174
lines = sio.getvalue().strip().split('\n')
8275
assert len(lines) == 4
8376
assert lines[0].split(',') == ['id', 'name']
8477
assert lines[1].split(',') == [str(objects[0].id), objects[0].name]
8578

86-
def test_rename_columns(self, session, objects, conn_type):
79+
def test_rename_columns(self, session, objects):
8780
sio = io.StringIO()
8881
flags = {'format': 'csv', 'header': True}
8982
query = relabel_query(session.query(Album.id, Album.name.label('title')))
90-
copy_to(query, sio, conn_type(session), **flags)
83+
copy_to(query, sio, session.connection().engine, **flags)
9184
lines = sio.getvalue().strip().split('\n')
9285
assert len(lines) == 4
9386
assert lines[0].split(',') == ['id', 'title']
9487
assert lines[1].split(',') == [str(objects[0].id), objects[0].name]
9588

96-
@pytest.mark.parametrize("conn_type", connection_types.values())
9789
class TestCopyFrom:
9890

99-
def test_copy_model(self, session, objects, conn_type):
91+
def test_copy_model(self, session, objects):
10092
sio = io.StringIO()
10193
sio.write(u'\t'.join(['4', 'The Works']))
10294
sio.seek(0)
103-
copy_from(sio, Album, conn_type(session))
95+
copy_from(sio, Album, session.connection().engine)
10496
assert session.query(Album).count() == len(objects) + 1
10597
row = session.query(Album).filter_by(id=4).first()
10698
assert row.id == 4
10799
assert row.name == 'The Works'
108100

109-
def test_copy_table(self, session, objects, conn_type):
101+
def test_copy_table(self, session, objects):
110102
sio = io.StringIO()
111103
sio.write(u'\t'.join(['4', 'The Works']))
112104
sio.seek(0)
113-
copy_from(sio, Album.__table__, conn_type(session))
105+
copy_from(sio, Album.__table__, session.connection().engine)
114106
assert session.query(Album).count() == len(objects) + 1
115107
row = session.query(Album).filter_by(id=4).first()
116108
assert row.id == 4
117109
assert row.name == 'The Works'
118110

119-
def test_copy_csv(self, session, objects, conn_type):
111+
def test_copy_csv(self, session, objects):
120112
sio = io.StringIO()
121113
sio.write(
122114
u'\n'.join([
@@ -126,18 +118,67 @@ def test_copy_csv(self, session, objects, conn_type):
126118
)
127119
sio.seek(0)
128120
flags = {'format': 'csv', 'header': True}
129-
copy_from(sio, Album, conn_type(session), **flags)
121+
copy_from(sio, Album, session.connection().engine, **flags)
130122
assert session.query(Album).count() == len(objects) + 1
131123
row = session.query(Album).filter_by(id=4).first()
132124
assert row.id == 4
133125
assert row.name == 'The Works'
134126

135-
def test_copy_columns(self, session, objects, conn_type):
127+
def test_copy_columns(self, session, objects):
136128
sio = io.StringIO()
137129
sio.write(u'\t'.join(['The Works', '4']))
138130
sio.seek(0)
139-
copy_from(sio, Album, conn_type(session), columns=('name', 'aid'))
131+
copy_from(sio, Album, session.connection().engine, columns=('name', 'aid'))
140132
assert session.query(Album).count() == len(objects) + 1
141133
row = session.query(Album).filter_by(id=4).first()
142134
assert row.id == 4
143135
assert row.name == 'The Works'
136+
137+
138+
class TestConnections:
139+
140+
@staticmethod
141+
def _verify_rollback(session, objects):
142+
assert session.query(Album).count() == len(objects)
143+
row = session.query(Album).filter_by(id=4).first()
144+
assert row is None
145+
146+
@staticmethod
147+
def _verify_commit(session, objects):
148+
assert session.query(Album).count() == len(objects) + 1
149+
row = session.query(Album).filter_by(id=4).first()
150+
assert row.id == 4
151+
assert row.name == 'The Works'
152+
153+
def _test_transactions(self, session, conn, sio, objects):
154+
# Test rollback
155+
sio.seek(0)
156+
session.execute('begin;')
157+
copy_from(sio, Album, conn)
158+
session.execute('rollback;')
159+
self._verify_rollback(session, objects)
160+
# Test commit
161+
sio.seek(0)
162+
session.execute('begin;')
163+
copy_from(sio, Album, conn)
164+
session.execute('commit;')
165+
self._verify_commit(session, objects)
166+
167+
def test_engine(self, session, objects):
168+
sio = io.StringIO()
169+
sio.write(u'\t'.join(['4', 'The Works']))
170+
sio.seek(0)
171+
copy_from(sio, Album, session.connection().engine)
172+
self._verify_commit(session, objects)
173+
174+
def test_connection_with_txn(self, session, objects):
175+
sio = io.StringIO()
176+
sio.write(u'\t'.join(['4', 'The Works']))
177+
conn = session.connection()
178+
self._test_transactions(session, conn, sio, objects)
179+
180+
def test_raw_connection_with_txn(self, session, objects):
181+
sio = io.StringIO()
182+
sio.write(u'\t'.join(['4', 'The Works']))
183+
raw_conn = session.connection().connection
184+
self._test_transactions(session, raw_conn, sio, objects)

0 commit comments

Comments
 (0)