@@ -18,10 +18,11 @@ import (
18
18
)
19
19
20
20
type mockConnector struct {
21
- t testing.TB
22
- conns uint32
23
- queryErr error
24
- execErr error
21
+ t testing.TB
22
+ conns uint32
23
+ queryErr error
24
+ execErr error
25
+ commitErr error
25
26
}
26
27
27
28
var _ driver.Connector = & mockConnector {}
@@ -292,3 +293,55 @@ func TestCleanUpResourcesOnPanicInRetryOperation(t *testing.T) {
292
293
})
293
294
})
294
295
}
296
+
297
+ func TestTxDoneErrorReturnsContextError (t * testing.T ) {
298
+ t .Run ("DoTxWithResult" , func (t * testing.T ) {
299
+ ctx , cancel := context .WithCancel (context .Background ())
300
+
301
+ m := & mockConnector {
302
+ t : t ,
303
+ commitErr : sql .ErrTxDone ,
304
+ }
305
+ db := sql .OpenDB (m )
306
+
307
+ attempts := 0
308
+ _ , err := DoTxWithResult (ctx , db ,
309
+ func (ctx context.Context , tx * sql.Tx ) (int , error ) {
310
+ attempts ++
311
+ cancel ()
312
+ time .Sleep (10 * time .Millisecond )
313
+
314
+ return 42 , nil
315
+ },
316
+ )
317
+
318
+ require .Error (t , err )
319
+ require .ErrorIs (t , err , context .Canceled )
320
+ require .Equal (t , 1 , attempts )
321
+ })
322
+
323
+ t .Run ("DoTx" , func (t * testing.T ) {
324
+ ctx , cancel := context .WithCancel (context .Background ())
325
+
326
+ m := & mockConnector {
327
+ t : t ,
328
+ commitErr : sql .ErrTxDone ,
329
+ }
330
+ db := sql .OpenDB (m )
331
+
332
+ attempts := 0
333
+ err := DoTx (ctx , db ,
334
+ func (ctx context.Context , tx * sql.Tx ) error {
335
+ attempts ++
336
+ cancel ()
337
+ time .Sleep (10 * time .Millisecond )
338
+
339
+ return nil
340
+ },
341
+ )
342
+
343
+ require .Error (t , err )
344
+ require .ErrorIs (t , err , context .Canceled )
345
+ require .Equal (t , 1 , attempts )
346
+ })
347
+ }
0 commit comments