8
8
9
9
Base = declarative_base ()
10
10
engine = sa .create_engine ('postgresql:///copy-test' )
11
- connection_types = {
12
- 'engine' : lambda session : session .connection ().engine ,
13
- 'connection' : lambda session : session .connection (),
14
- 'raw_connection' : lambda session : session .connection ().connection ,
15
- }
16
11
17
12
class Album (Base ):
18
13
__tablename__ = 'album'
@@ -44,79 +39,76 @@ def objects(session):
44
39
finally :
45
40
engine .execute (Album .__table__ .delete ())
46
41
47
- @pytest .mark .parametrize ("conn_type" , connection_types .values ())
48
42
class TestCopyTo :
49
43
50
- def test_copy_query (self , session , objects , conn_type ):
44
+ def test_copy_query (self , session , objects ):
51
45
sio = io .StringIO ()
52
- copy_to (session .query (Album ), sio , conn_type ( session ) )
46
+ copy_to (session .query (Album ), sio , session . connection (). engine )
53
47
lines = sio .getvalue ().strip ().split ('\n ' )
54
48
assert len (lines ) == 3
55
49
assert lines [0 ].split ('\t ' ) == [str (objects [0 ].id ), objects [0 ].name ]
56
50
57
- def test_copy_table (self , session , objects , conn_type ):
51
+ def test_copy_table (self , session , objects ):
58
52
sio = io .StringIO ()
59
- copy_to (Album .__table__ .select (), sio , conn_type ( session ) )
53
+ copy_to (Album .__table__ .select (), sio , session . connection (). engine )
60
54
lines = sio .getvalue ().strip ().split ('\n ' )
61
55
assert len (lines ) == 3
62
56
assert lines [0 ].split ('\t ' ) == [str (objects [0 ].id ), objects [0 ].name ]
63
57
64
- def test_copy_csv (self , session , objects , conn_type ):
58
+ def test_copy_csv (self , session , objects ):
65
59
sio = io .StringIO ()
66
60
flags = {'format' : 'csv' , 'header' : True }
67
- copy_to (session .query (Album ), sio , conn_type ( session ) , ** flags )
61
+ copy_to (session .query (Album ), sio , session . connection (). engine , ** flags )
68
62
lines = sio .getvalue ().strip ().split ('\n ' )
69
63
assert len (lines ) == 4
70
64
assert lines [0 ].split (',' ) == ['aid' , 'name' ]
71
65
assert lines [1 ].split (',' ) == [str (objects [0 ].id ), objects [0 ].name ]
72
66
73
- @pytest .mark .parametrize ("conn_type" , connection_types .values ())
74
67
class TestCopyRename :
75
68
76
- def test_rename_model (self , session , objects , conn_type ):
69
+ def test_rename_model (self , session , objects ):
77
70
sio = io .StringIO ()
78
71
flags = {'format' : 'csv' , 'header' : True }
79
72
query = relabel_query (session .query (Album ))
80
- copy_to (query , sio , conn_type ( session ) , ** flags )
73
+ copy_to (query , sio , session . connection (). engine , ** flags )
81
74
lines = sio .getvalue ().strip ().split ('\n ' )
82
75
assert len (lines ) == 4
83
76
assert lines [0 ].split (',' ) == ['id' , 'name' ]
84
77
assert lines [1 ].split (',' ) == [str (objects [0 ].id ), objects [0 ].name ]
85
78
86
- def test_rename_columns (self , session , objects , conn_type ):
79
+ def test_rename_columns (self , session , objects ):
87
80
sio = io .StringIO ()
88
81
flags = {'format' : 'csv' , 'header' : True }
89
82
query = relabel_query (session .query (Album .id , Album .name .label ('title' )))
90
- copy_to (query , sio , conn_type ( session ) , ** flags )
83
+ copy_to (query , sio , session . connection (). engine , ** flags )
91
84
lines = sio .getvalue ().strip ().split ('\n ' )
92
85
assert len (lines ) == 4
93
86
assert lines [0 ].split (',' ) == ['id' , 'title' ]
94
87
assert lines [1 ].split (',' ) == [str (objects [0 ].id ), objects [0 ].name ]
95
88
96
- @pytest .mark .parametrize ("conn_type" , connection_types .values ())
97
89
class TestCopyFrom :
98
90
99
- def test_copy_model (self , session , objects , conn_type ):
91
+ def test_copy_model (self , session , objects ):
100
92
sio = io .StringIO ()
101
93
sio .write (u'\t ' .join (['4' , 'The Works' ]))
102
94
sio .seek (0 )
103
- copy_from (sio , Album , conn_type ( session ) )
95
+ copy_from (sio , Album , session . connection (). engine )
104
96
assert session .query (Album ).count () == len (objects ) + 1
105
97
row = session .query (Album ).filter_by (id = 4 ).first ()
106
98
assert row .id == 4
107
99
assert row .name == 'The Works'
108
100
109
- def test_copy_table (self , session , objects , conn_type ):
101
+ def test_copy_table (self , session , objects ):
110
102
sio = io .StringIO ()
111
103
sio .write (u'\t ' .join (['4' , 'The Works' ]))
112
104
sio .seek (0 )
113
- copy_from (sio , Album .__table__ , conn_type ( session ) )
105
+ copy_from (sio , Album .__table__ , session . connection (). engine )
114
106
assert session .query (Album ).count () == len (objects ) + 1
115
107
row = session .query (Album ).filter_by (id = 4 ).first ()
116
108
assert row .id == 4
117
109
assert row .name == 'The Works'
118
110
119
- def test_copy_csv (self , session , objects , conn_type ):
111
+ def test_copy_csv (self , session , objects ):
120
112
sio = io .StringIO ()
121
113
sio .write (
122
114
u'\n ' .join ([
@@ -126,18 +118,67 @@ def test_copy_csv(self, session, objects, conn_type):
126
118
)
127
119
sio .seek (0 )
128
120
flags = {'format' : 'csv' , 'header' : True }
129
- copy_from (sio , Album , conn_type ( session ) , ** flags )
121
+ copy_from (sio , Album , session . connection (). engine , ** flags )
130
122
assert session .query (Album ).count () == len (objects ) + 1
131
123
row = session .query (Album ).filter_by (id = 4 ).first ()
132
124
assert row .id == 4
133
125
assert row .name == 'The Works'
134
126
135
- def test_copy_columns (self , session , objects , conn_type ):
127
+ def test_copy_columns (self , session , objects ):
136
128
sio = io .StringIO ()
137
129
sio .write (u'\t ' .join (['The Works' , '4' ]))
138
130
sio .seek (0 )
139
- copy_from (sio , Album , conn_type ( session ) , columns = ('name' , 'aid' ))
131
+ copy_from (sio , Album , session . connection (). engine , columns = ('name' , 'aid' ))
140
132
assert session .query (Album ).count () == len (objects ) + 1
141
133
row = session .query (Album ).filter_by (id = 4 ).first ()
142
134
assert row .id == 4
143
135
assert row .name == 'The Works'
136
+
137
+
138
+ class TestConnections :
139
+
140
+ @staticmethod
141
+ def _verify_rollback (session , objects ):
142
+ assert session .query (Album ).count () == len (objects )
143
+ row = session .query (Album ).filter_by (id = 4 ).first ()
144
+ assert row is None
145
+
146
+ @staticmethod
147
+ def _verify_commit (session , objects ):
148
+ assert session .query (Album ).count () == len (objects ) + 1
149
+ row = session .query (Album ).filter_by (id = 4 ).first ()
150
+ assert row .id == 4
151
+ assert row .name == 'The Works'
152
+
153
+ def _test_transactions (self , session , conn , sio , objects ):
154
+ # Test rollback
155
+ sio .seek (0 )
156
+ session .execute ('begin;' )
157
+ copy_from (sio , Album , conn )
158
+ session .execute ('rollback;' )
159
+ self ._verify_rollback (session , objects )
160
+ # Test commit
161
+ sio .seek (0 )
162
+ session .execute ('begin;' )
163
+ copy_from (sio , Album , conn )
164
+ session .execute ('commit;' )
165
+ self ._verify_commit (session , objects )
166
+
167
+ def test_engine (self , session , objects ):
168
+ sio = io .StringIO ()
169
+ sio .write (u'\t ' .join (['4' , 'The Works' ]))
170
+ sio .seek (0 )
171
+ copy_from (sio , Album , session .connection ().engine )
172
+ self ._verify_commit (session , objects )
173
+
174
+ def test_connection_with_txn (self , session , objects ):
175
+ sio = io .StringIO ()
176
+ sio .write (u'\t ' .join (['4' , 'The Works' ]))
177
+ conn = session .connection ()
178
+ self ._test_transactions (session , conn , sio , objects )
179
+
180
+ def test_raw_connection_with_txn (self , session , objects ):
181
+ sio = io .StringIO ()
182
+ sio .write (u'\t ' .join (['4' , 'The Works' ]))
183
+ raw_conn = session .connection ().connection
184
+ self ._test_transactions (session , raw_conn , sio , objects )
0 commit comments