diff --git a/retry.go b/retry.go index 5338985..3d7fd0d 100644 --- a/retry.go +++ b/retry.go @@ -136,38 +136,6 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( return emptyT, err } - // Setting attempts to 0 means we'll retry until we succeed - var lastErr error - if config.attempts == 0 { - for { - t, err := retryableFunc() - if err == nil { - return t, nil - } - - if !IsRecoverable(err) { - return emptyT, err - } - - if !config.retryIf(err) { - return emptyT, err - } - - lastErr = err - - config.onRetry(n, err) - n++ - select { - case <-config.timer.After(delay(config, n, err)): - case <-config.context.Done(): - if config.wrapContextErrorWithLastError { - return emptyT, Error{context.Cause(config.context), lastErr} - } - return emptyT, context.Cause(config.context) - } - } - } - errorLog := Error{} attemptsForError := make(map[error]uint, len(config.attemptsForError)) @@ -184,6 +152,10 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( errorLog = append(errorLog, unpackUnrecoverable(err)) + if !IsRecoverable(err) { + return emptyT, err + } + if !config.retryIf(err) { break } @@ -198,8 +170,9 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( } } + // Setting attempts to 0 means we'll retry until we succeed // if this is last attempt - don't wait - if n == config.attempts-1 { + if config.attempts != 0 && n == config.attempts-1 { break } n++ @@ -213,7 +186,6 @@ func DoWithData[T any](retryableFunc RetryableFuncWithData[T], opts ...Option) ( return emptyT, append(errorLog, context.Cause(config.context)) } - shouldRetry = shouldRetry && n < config.attempts } if config.lastErrorOnly { diff --git a/retry_test.go b/retry_test.go index 1ee3739..edaad38 100644 --- a/retry_test.go +++ b/retry_test.go @@ -104,6 +104,7 @@ func TestRetryIf_ZeroAttempts(t *testing.T) { return err.Error() != "special" }), Delay(time.Nanosecond), + LastErrorOnly(true), Attempts(0), ) assert.Error(t, err) @@ -215,7 +216,6 @@ func TestLastErrorOnly(t *testing.T) { func TestUnrecoverableError(t *testing.T) { attempts := 0 testErr := errors.New("error") - expectedErr := Error{testErr} err := Do( func() error { attempts++ @@ -223,8 +223,8 @@ func TestUnrecoverableError(t *testing.T) { }, Attempts(2), ) - assert.Equal(t, expectedErr, err) - assert.Equal(t, testErr, errors.Unwrap(err)) + assert.Error(t, err) + assert.Equal(t, Unrecoverable(testErr), err) assert.Equal(t, 1, attempts, "unrecoverable error broke the loop") } @@ -457,6 +457,7 @@ func TestContext(t *testing.T) { cancel() } }), + LastErrorOnly(true), Context(ctx), Attempts(0), )