diff --git a/circuitbreaker/README.md b/circuitbreaker/README.md new file mode 100644 index 0000000..e15b617 --- /dev/null +++ b/circuitbreaker/README.md @@ -0,0 +1,109 @@ +# Echo Circuit Breaker Middleware + +A robust circuit breaker implementation for the [Echo](https://echo.labstack.com/) web framework, providing fault tolerance and graceful service degradation. + +## Overview + +The Circuit Breaker pattern helps prevent cascading failures in distributed systems. When dependencies fail or become slow, the circuit breaker "trips" and fails fast, preventing system overload and allowing time for recovery. + +This middleware implements a full-featured circuit breaker with three states: + +- **Closed** (normal operation): Requests flow through normally +- **Open** (failure mode): Requests are rejected immediately without reaching the protected service +- **Half-Open** (recovery testing): Limited requests are allowed through to test if the service has recovered + +## Features + +- Configurable failure and success thresholds +- Automatic state transitions based on error rates +- Customizable timeout periods +- Controlled recovery with half-open state limiting concurrent requests +- Support for custom failure detection +- State transition callbacks +- Comprehensive metrics and monitoring + +## Installation + +```bash +go get github.com/labstack/echo-contrib/circuitbreaker +``` + +### Basic Usage + +```go +package main + +import ( + "github.com/labstack/echo/v4" + "github.com/labstack/echo-contrib//circuitbreaker" +) + +func main() { + e := echo.New() + + // Create a circuit breaker with default configuration + cb := circuitbreaker.New(circuitbreaker.DefaultConfig) + + // Apply it to specific routes + e.GET("/protected", protectedHandler, circuitbreaker.Middleware(cb)) + + e.Start(":8080") +} + +func protectedHandler(c echo.Context) error { + // Your handler code here + return c.String(200, "Service is healthy") +} +``` + +### Advanced Usage + +```go +cb := circuitbreaker.New(circuitbreaker.Config{ + FailureThreshold: 10, // Number of failures before circuit opens + Timeout: 30 * time.Second, // How long circuit stays open + SuccessThreshold: 2, // Successes needed to close circuit + HalfOpenMaxConcurrent: 5, // Max concurrent requests in half-open state + IsFailure: func(c echo.Context, err error) bool { + // Custom failure detection logic + return err != nil || c.Response().Status >= 500 + }, + OnOpen: func(c echo.Context) error { + // Custom handling when circuit opens + return c.JSON(503, map[string]string{"status": "service temporarily unavailable"}) + }, +}) +``` + +### Monitoring and Metrics + +```go +// Get current metrics +metrics := cb.Metrics() + +// Add a metrics endpoint +e.GET("/metrics/circuit", func(c echo.Context) error { + return c.JSON(200, cb.GetStateStats()) +}) +``` + +### State Management + +```go +// Force circuit open (for maintenance, etc.) +cb.ForceOpen() + +// Force circuit closed +cb.ForceClose() + +// Reset circuit to initial state +cb.Reset() +``` + +### Best Practices + +1. Use circuit breakers for external dependencies: APIs, databases, etc. +2. Set appropriate thresholds: Too low may cause premature circuit opening, too high may not protect effectively +3. Monitor circuit state: Add logging/metrics for state transitions +4. Consider service degradation: Provide fallbacks when circuit is open +5. Set timeouts appropriately: Match timeout to expected recovery time diff --git a/circuitbreaker/circuit_breaker.go b/circuitbreaker/circuit_breaker.go new file mode 100644 index 0000000..50f78fb --- /dev/null +++ b/circuitbreaker/circuit_breaker.go @@ -0,0 +1,445 @@ +package circuitbreaker + +import ( + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/labstack/echo/v4" +) + +// State represents the state of the circuit breaker +type State string + +const ( + StateClosed State = "closed" // Normal operation + StateOpen State = "open" // Requests are blocked + StateHalfOpen State = "half-open" // Limited requests allowed to check recovery +) + +// MetricsData represents circuit breaker metrics with proper types +type MetricsData struct { + State State `json:"state"` + Failures int64 `json:"failures"` + Successes int64 `json:"successes"` + TotalRequests int64 `json:"totalRequests"` + RejectedRequests int64 `json:"rejectedRequests"` +} + +// Config holds the configurable parameters +type Config struct { + // Failure threshold to trip the circuit + FailureThreshold int + // Duration circuit stays open before allowing test requests + Timeout time.Duration + // Success threshold to close the circuit from half-open + SuccessThreshold int + // Maximum concurrent requests allowed in half-open state + HalfOpenMaxConcurrent int64 + // Custom failure detector function (return true if response should count as failure) + IsFailure func(c echo.Context, err error) bool + // Callbacks for state transitions + OnOpen func(echo.Context) error // Called when circuit opens + OnHalfOpen func(echo.Context) error // Called when circuit transitions to half-open + OnClose func(echo.Context) error // Called when circuit closes +} + +// DefaultConfig provides sensible defaults for the circuit breaker +var DefaultConfig = Config{ + FailureThreshold: 5, + Timeout: 5 * time.Second, + SuccessThreshold: 1, + HalfOpenMaxConcurrent: 1, + IsFailure: func(c echo.Context, err error) bool { + return err != nil || c.Response().Status >= http.StatusInternalServerError + }, + OnOpen: func(c echo.Context) error { + return c.JSON(http.StatusServiceUnavailable, map[string]interface{}{ + "error": "service unavailable", + }) + }, + OnHalfOpen: func(c echo.Context) error { + return c.JSON(http.StatusTooManyRequests, map[string]interface{}{ + "error": "service under recovery", + }) + }, + OnClose: func(c echo.Context) error { + return nil + }, +} + +// HalfOpenLimiter manages concurrent requests in half-open state +type HalfOpenLimiter struct { + maxConcurrent int64 + current atomic.Int64 +} + +// NewHalfOpenLimiter creates a new limiter for half-open state +func NewHalfOpenLimiter(maxConcurrent int64) *HalfOpenLimiter { + return &HalfOpenLimiter{ + maxConcurrent: maxConcurrent, + } +} + +// TryAcquire attempts to acquire a slot in the limiter +// Returns true if successful, false if at capacity +func (l *HalfOpenLimiter) TryAcquire() bool { + for { + current := l.current.Load() + if current >= l.maxConcurrent { + return false + } + + if l.current.CompareAndSwap(current, current+1) { + return true + } + } +} + +// Release releases a previously acquired slot +func (l *HalfOpenLimiter) Release() { + current := l.current.Load() + if current > 0 { + l.current.CompareAndSwap(current, current-1) + } +} + +// CircuitBreaker implements the circuit breaker pattern +type CircuitBreaker struct { + failureCount atomic.Int64 // Count of failures + successCount atomic.Int64 // Count of successes in half-open state + totalRequests atomic.Int64 // Count of total requests + rejectedRequests atomic.Int64 // Count of rejected requests + state State // Current state of circuit breaker + mutex sync.RWMutex // Protects state transitions + failureThreshold int // Max failures before opening circuit + timeout time.Duration // Duration to stay open before transitioning to half-open + successThreshold int // Successes required to close circuit + openUntil atomic.Int64 // Unix timestamp (nanos) when open state expires + config Config // Configuration settings + now func() time.Time // Function for getting current time (useful for testing) + halfOpenLimiter *HalfOpenLimiter // Controls limited requests in half-open state + lastStateChange time.Time // Time of last state change +} + +// New initializes a circuit breaker with the given configuration +func New(config Config) *CircuitBreaker { + // Apply default values for zero values + if config.FailureThreshold <= 0 { + config.FailureThreshold = DefaultConfig.FailureThreshold + } + if config.Timeout <= 0 { + config.Timeout = DefaultConfig.Timeout + } + if config.SuccessThreshold <= 0 { + config.SuccessThreshold = DefaultConfig.SuccessThreshold + } + if config.HalfOpenMaxConcurrent <= 0 { + config.HalfOpenMaxConcurrent = DefaultConfig.HalfOpenMaxConcurrent + } + if config.IsFailure == nil { + config.IsFailure = DefaultConfig.IsFailure + } + if config.OnOpen == nil { + config.OnOpen = DefaultConfig.OnOpen + } + if config.OnHalfOpen == nil { + config.OnHalfOpen = DefaultConfig.OnHalfOpen + } + if config.OnClose == nil { + config.OnClose = DefaultConfig.OnClose + } + + now := time.Now() + + return &CircuitBreaker{ + failureThreshold: config.FailureThreshold, + timeout: config.Timeout, + successThreshold: config.SuccessThreshold, + state: StateClosed, + config: config, + now: time.Now, + halfOpenLimiter: NewHalfOpenLimiter(config.HalfOpenMaxConcurrent), + lastStateChange: now, + } +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) State() State { + // Check for auto-transition from open to half-open based on timestamp + if cb.state == StateOpen { + openUntil := cb.openUntil.Load() + if openUntil > 0 && cb.now().UnixNano() >= openUntil { + cb.transitionToHalfOpen() + } + } + + cb.mutex.RLock() + defer cb.mutex.RUnlock() + return cb.state +} + +// IsOpen returns true if the circuit is open +func (cb *CircuitBreaker) IsOpen() bool { + return cb.State() == StateOpen +} + +// Reset resets the circuit breaker to its initial closed state +func (cb *CircuitBreaker) Reset() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + // Reset counters + cb.failureCount.Store(0) + cb.successCount.Store(0) + + // Reset state + cb.state = StateClosed + cb.lastStateChange = cb.now() + cb.openUntil.Store(0) +} + +// ForceOpen forcibly opens the circuit regardless of failure count +func (cb *CircuitBreaker) ForceOpen() { + cb.transitionToOpen() +} + +// ForceClose forcibly closes the circuit regardless of current state +func (cb *CircuitBreaker) ForceClose() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + cb.state = StateClosed + cb.lastStateChange = cb.now() + cb.failureCount.Store(0) + cb.successCount.Store(0) + cb.openUntil.Store(0) +} + +// SetTimeout updates the timeout duration +func (cb *CircuitBreaker) SetTimeout(timeout time.Duration) { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + cb.timeout = timeout +} + +// transitionToOpen changes state to open and sets timestamp for half-open transition +func (cb *CircuitBreaker) transitionToOpen() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + if cb.state == StateOpen { + return + } + + cb.state = StateOpen + cb.lastStateChange = cb.now() + + // Set timestamp when the circuit should transition to half-open + openUntil := cb.now().Add(cb.timeout).UnixNano() + cb.openUntil.Store(openUntil) + + // Reset failure counter + cb.failureCount.Store(0) +} + +// transitionToHalfOpen changes state from open to half-open +func (cb *CircuitBreaker) transitionToHalfOpen() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + if cb.state == StateOpen { + cb.state = StateHalfOpen + cb.lastStateChange = cb.now() + + // Reset counters + cb.failureCount.Store(0) + cb.successCount.Store(0) + cb.openUntil.Store(0) + } +} + +// transitionToClosed changes state from half-open to closed +func (cb *CircuitBreaker) transitionToClosed() { + cb.mutex.Lock() + defer cb.mutex.Unlock() + + if cb.state != StateHalfOpen { + return + } + + // Transition to closed state + cb.state = StateClosed + cb.lastStateChange = cb.now() + + // Reset counters + cb.failureCount.Store(0) + cb.successCount.Store(0) + +} + +// AllowRequest determines if a request is allowed based on circuit state +func (cb *CircuitBreaker) AllowRequest() (bool, State) { + cb.totalRequests.Add(1) + + // First check if we need to transition from open to half-open + // Use a single lock section to check and potentially transition + cb.mutex.Lock() + currentState := cb.state + if currentState == StateOpen { + openUntil := cb.openUntil.Load() + if openUntil > 0 && cb.now().UnixNano() >= openUntil { + // Use the existing transition method instead of duplicating logic + cb.state = StateHalfOpen + cb.lastStateChange = cb.now() + cb.failureCount.Store(0) + cb.successCount.Store(0) + cb.openUntil.Store(0) + currentState = StateHalfOpen + } + } + + // Determine if the request is allowed based on the current state + var allowed bool + switch currentState { + case StateOpen: // Block all requests + allowed = false + case StateHalfOpen: // Allow limited requests + // Check if we can acquire a slot in the half-open limiter + // Use TryAcquire to avoid blocking + // This is a non-blocking call, so it won't wait for a slot + // to become available + // If the limit is reached, we return false + // and increment the rejected requests counter + allowed = cb.halfOpenLimiter.TryAcquire() + default: // StateClosed + allowed = true + } + cb.mutex.Unlock() + + if !allowed { + cb.rejectedRequests.Add(1) + } + + return allowed, currentState +} + +// ReleaseHalfOpen releases a slot in the half-open limiter +func (cb *CircuitBreaker) ReleaseHalfOpen() { + if cb.State() == StateHalfOpen { + cb.halfOpenLimiter.Release() + } +} + +// ReportSuccess increments success count and closes circuit if threshold met +func (cb *CircuitBreaker) ReportSuccess() { + if cb.State() == StateHalfOpen { + newSuccessCount := cb.successCount.Add(1) + if int(newSuccessCount) >= cb.successThreshold { + cb.transitionToClosed() + } + } +} + +// ReportFailure increments failure count and opens circuit if threshold met +func (cb *CircuitBreaker) ReportFailure() { + state := cb.State() + + switch state { + case StateHalfOpen: + // In half-open, a single failure trips the circuit + cb.transitionToOpen() + case StateClosed: + newFailureCount := cb.failureCount.Add(1) + if int(newFailureCount) >= cb.failureThreshold { + cb.transitionToOpen() + } + } +} + +// Metrics returns basic metrics about the circuit breaker +func (cb *CircuitBreaker) Metrics() MetricsData { + return MetricsData{ + State: cb.State(), + Failures: cb.failureCount.Load(), + Successes: cb.successCount.Load(), + TotalRequests: cb.totalRequests.Load(), + RejectedRequests: cb.rejectedRequests.Load(), + } +} + +// GetStateStats returns detailed statistics about the circuit breaker +func (cb *CircuitBreaker) GetStateStats() map[string]interface{} { + state := cb.State() + openUntil := cb.openUntil.Load() + + stats := map[string]interface{}{ + "state": state, + "failures": cb.failureCount.Load(), + "successes": cb.successCount.Load(), + "totalRequests": cb.totalRequests.Load(), + "rejectedRequests": cb.rejectedRequests.Load(), + "lastStateChange": cb.lastStateChange, + "openDuration": cb.timeout, + "failureThreshold": cb.failureThreshold, + "successThreshold": cb.successThreshold, + } + + if openUntil > 0 { + stats["openUntil"] = time.Unix(0, openUntil) + } + + return stats +} + +// Middleware wraps the echo handler with circuit breaker logic +func Middleware(cb *CircuitBreaker) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + + if cb == nil { + return next(c) + } + + allowed, state := cb.AllowRequest() + + if !allowed { + // Call appropriate callback based on state + if state == StateHalfOpen && cb.config.OnHalfOpen != nil { + return cb.config.OnHalfOpen(c) + } else if state == StateOpen && cb.config.OnOpen != nil { + return cb.config.OnOpen(c) + } + return c.NoContent(http.StatusServiceUnavailable) + } + + // If request allowed in half-open state, ensure limiter is released + halfOpen := state == StateHalfOpen + if halfOpen { + defer cb.ReleaseHalfOpen() + } + + // Execute the request + err := next(c) + + // Check if the response should be considered a failure + if cb.config.IsFailure(c, err) { + cb.ReportFailure() + } else { + cb.ReportSuccess() + + // If transition to closed state just happened, trigger callback + if halfOpen && cb.State() == StateClosed && cb.config.OnClose != nil { + if closeErr := cb.config.OnClose(c); closeErr != nil { + // Log the error but don't override the actual response + c.Logger().Errorf("Circuit breaker OnClose callback error: %v", closeErr) + } + } + } + + return err + } + } +} diff --git a/circuitbreaker/circuit_breaker_test.go b/circuitbreaker/circuit_breaker_test.go new file mode 100644 index 0000000..8423bf0 --- /dev/null +++ b/circuitbreaker/circuit_breaker_test.go @@ -0,0 +1,323 @@ +package circuitbreaker + +import ( + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +// TestCircuitBreakerBasicOperations tests basic operations of the circuit breaker +func TestCircuitBreakerBasicOperations(t *testing.T) { + // Create circuit breaker with custom config + cb := New(Config{ + FailureThreshold: 3, + Timeout: 100 * time.Millisecond, + SuccessThreshold: 2, + HalfOpenMaxConcurrent: 2, + }) + + // Initial state should be closed + assert.Equal(t, StateClosed, cb.State()) + assert.False(t, cb.IsOpen()) + + // Test state transitions + t.Run("State transitions", func(t *testing.T) { + // Reporting failures should eventually open the circuit + for i := 0; i < 3; i++ { + cb.ReportFailure() + } + assert.Equal(t, StateOpen, cb.State()) + assert.True(t, cb.IsOpen()) + + // Requests should be rejected in open state + allowed, state := cb.AllowRequest() + assert.False(t, allowed) + assert.Equal(t, StateOpen, state) + + // Wait for timeout to transition to half-open + time.Sleep(150 * time.Millisecond) + assert.Equal(t, StateHalfOpen, cb.State()) + + // Some requests should be allowed in half-open state + allowed, state = cb.AllowRequest() + assert.True(t, allowed) + assert.Equal(t, StateHalfOpen, state) + + // Report successes to close the circuit + cb.ReportSuccess() + cb.ReportSuccess() + assert.Equal(t, StateClosed, cb.State()) + }) + + // Reset the circuit breaker + cb.Reset() + assert.Equal(t, StateClosed, cb.State()) + + t.Run("Force state changes", func(t *testing.T) { + // Force open + cb.ForceOpen() + assert.Equal(t, StateOpen, cb.State()) + + // Force close + cb.ForceClose() + assert.Equal(t, StateClosed, cb.State()) + }) +} + +// TestCircuitBreakerHalfOpenConcurrency tests the half-open state with concurrency +func TestCircuitBreakerHalfOpenConcurrency(t *testing.T) { + // Create circuit breaker that allows 2 concurrent requests in half-open + cb := New(Config{ + FailureThreshold: 1, + Timeout: 100 * time.Millisecond, + SuccessThreshold: 2, + HalfOpenMaxConcurrent: 2, + }) + + // Force into half-open state + cb.ForceOpen() + time.Sleep(150 * time.Millisecond) + assert.Equal(t, StateHalfOpen, cb.State()) + + // First two requests should be allowed + allowed1, _ := cb.AllowRequest() + allowed2, _ := cb.AllowRequest() + assert.True(t, allowed1) + assert.True(t, allowed2) + + // Third request should be rejected + allowed3, _ := cb.AllowRequest() + assert.False(t, allowed3) + + // After releasing one slot, a new request should be allowed + cb.ReleaseHalfOpen() + allowed4, _ := cb.AllowRequest() + assert.True(t, allowed4) +} + +// TestCircuitBreakerConcurrency tests the concurrency safety of the circuit breaker +func TestCircuitBreakerConcurrency(t *testing.T) { + cb := New(Config{ + FailureThreshold: 5, + Timeout: 100 * time.Millisecond, + SuccessThreshold: 3, + HalfOpenMaxConcurrent: 2, + }) + + // Test concurrent requests + t.Run("Concurrent requests", func(t *testing.T) { + var wg sync.WaitGroup + numRequests := 100 + + wg.Add(numRequests) + for i := 0; i < numRequests; i++ { + go func(i int) { + defer wg.Done() + allowed, _ := cb.AllowRequest() + if allowed && i%2 == 0 { + cb.ReportSuccess() + } else if allowed { + cb.ReportFailure() + } + }(i) + } + + wg.Wait() + metrics := cb.Metrics() + assert.Equal(t, int64(numRequests), metrics.TotalRequests) + }) +} + +// TestCircuitBreakerMetrics checks the metrics of the circuit breaker +func TestCircuitBreakerMetrics(t *testing.T) { + cb := New(DefaultConfig) + + // Report some activities + cb.ReportFailure() + allowed, _ := cb.AllowRequest() + assert.True(t, allowed) + cb.ReportSuccess() + + // Check basic metrics + metrics := cb.Metrics() + assert.Equal(t, int64(1), metrics.Failures) + assert.Equal(t, int64(1), metrics.TotalRequests) + assert.Equal(t, StateClosed, metrics.State) + + // Check detailed stats + stats := cb.GetStateStats() + assert.Equal(t, DefaultConfig.FailureThreshold, stats["failureThreshold"]) + assert.Equal(t, DefaultConfig.SuccessThreshold, stats["successThreshold"]) + assert.Equal(t, DefaultConfig.Timeout, stats["openDuration"]) +} + +// TestMiddleware tests the middleware functionality +func TestMiddleware(t *testing.T) { + // Setup + e := echo.New() + cb := New(DefaultConfig) + + // Create a test handler that can be configured to succeed or fail + shouldFail := false + testHandler := func(c echo.Context) error { + if shouldFail { + return echo.NewHTTPError(http.StatusInternalServerError, "test error") + } + return c.String(http.StatusOK, "success") + } + + // Apply middleware + handler := Middleware(cb)(testHandler) + + t.Run("Success case", func(t *testing.T) { + // Create request and recorder + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Execute request + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + // Check metrics + metrics := cb.Metrics() + assert.Equal(t, int64(1), metrics.TotalRequests) + assert.Equal(t, int64(0), metrics.Failures) + }) + + t.Run("Failure case", func(t *testing.T) { + // Configure handler to fail + shouldFail = true + + // Create request and recorder + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Execute request and expect error (which middleware passes through) + err := handler(c) + assert.Error(t, err) + + // Check metrics - failures should be incremented + metrics := cb.Metrics() + assert.Equal(t, int64(2), metrics.TotalRequests) + assert.Equal(t, int64(1), metrics.Failures) + + // Force more failures to open the circuit + for i := 0; i < DefaultConfig.FailureThreshold-1; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + _ = handler(c) + } + + // Circuit should now be open + assert.Equal(t, StateOpen, cb.State()) + + // Requests should be rejected + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + err = handler(c) + assert.NoError(t, err) // OnOpen callback handles the response + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + }) +} + +// TestCustomCallbacks tests custom callbacks +func TestCustomCallbacks(t *testing.T) { + callbackInvoked := false + + cb := New(Config{ + FailureThreshold: 1, + Timeout: 100 * time.Millisecond, + SuccessThreshold: 1, + HalfOpenMaxConcurrent: 1, + OnOpen: func(c echo.Context) error { + callbackInvoked = true + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "error": "circuit open", + "status": "unavailable", + }) + }, + }) + + // Setup Echo + e := echo.New() + testHandler := func(c echo.Context) error { + return errors.New("some error") + } + handler := Middleware(cb)(testHandler) + + // First request opens the circuit + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + _ = handler(c) + + // Second request should invoke the callback + req = httptest.NewRequest(http.MethodGet, "/", nil) + rec = httptest.NewRecorder() + c = e.NewContext(req, rec) + _ = handler(c) + + assert.True(t, callbackInvoked) + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + assert.Contains(t, rec.Body.String(), "circuit open") +} + +// TestErrorHandling tests error handling in callbacks +func TestErrorHandling(t *testing.T) { + errorCalled := false + + // Create a logger that captures errors + e := echo.New() + e.Logger.SetOutput(new(testLogWriter)) + + // Create circuit breaker with callbacks that return errors + cb := New(Config{ + FailureThreshold: 1, + Timeout: 100 * time.Millisecond, + SuccessThreshold: 1, + HalfOpenMaxConcurrent: 1, + OnClose: func(c echo.Context) error { + errorCalled = true + return errors.New("test error in callback") + }, + }) + + // Force into half-open state + cb.ForceOpen() + time.Sleep(150 * time.Millisecond) + + // Create handler + testHandler := func(c echo.Context) error { + return nil // Success + } + handler := Middleware(cb)(testHandler) + + // Execute request to trigger transition to closed + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + err := handler(c) + + // The error should be logged but not returned + assert.NoError(t, err) + assert.True(t, errorCalled) + assert.Equal(t, http.StatusOK, rec.Code) +} + +// Helper type for capturing logs +type testLogWriter struct{} + +func (w *testLogWriter) Write(p []byte) (n int, err error) { + return len(p), nil +}