diff --git a/auth/reauth_credentials_listener.go b/auth/reauth_credentials_listener.go index 40076a0b1..f4b319838 100644 --- a/auth/reauth_credentials_listener.go +++ b/auth/reauth_credentials_listener.go @@ -44,4 +44,4 @@ func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, on } // Ensure ReAuthCredentialsListener implements the CredentialsListener interface. -var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) \ No newline at end of file diff --git a/error.go b/error.go index 8013de44a..7273313b5 100644 --- a/error.go +++ b/error.go @@ -108,10 +108,12 @@ func isRedisError(err error) bool { func isBadConn(err error, allowTimeout bool, addr string) bool { switch err { - case nil: - return false - case context.Canceled, context.DeadlineExceeded: - return true + case nil: + return false + case context.Canceled, context.DeadlineExceeded: + return true + case pool.ErrConnUnusableTimeout: + return true } if isRedisError(err) { diff --git a/internal/auth/streaming/conn_reauth_credentials_listener.go b/internal/auth/streaming/conn_reauth_credentials_listener.go new file mode 100644 index 000000000..22bfedf71 --- /dev/null +++ b/internal/auth/streaming/conn_reauth_credentials_listener.go @@ -0,0 +1,100 @@ +package streaming + +import ( + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ConnReAuthCredentialsListener is a credentials listener for a specific connection +// that triggers re-authentication when credentials change. +// +// This listener implements the auth.CredentialsListener interface and is subscribed +// to a StreamingCredentialsProvider. When new credentials are received via OnNext, +// it marks the connection for re-authentication through the manager. +// +// The re-authentication is always performed asynchronously to avoid blocking the +// credentials provider and to prevent potential deadlocks with the pool semaphore. +// The actual re-auth happens when the connection is returned to the pool in an idle state. +// +// Lifecycle: +// - Created during connection initialization via Manager.Listener() +// - Subscribed to the StreamingCredentialsProvider +// - Receives credential updates via OnNext() +// - Cleaned up when connection is removed from pool via Manager.RemoveListener() +type ConnReAuthCredentialsListener struct { + // reAuth is the function to re-authenticate the connection with new credentials + reAuth func(conn *pool.Conn, credentials auth.Credentials) error + + // onErr is the function to call when re-authentication or acquisition fails + onErr func(conn *pool.Conn, err error) + + // conn is the connection this listener is associated with + conn *pool.Conn + + // manager is the streaming credentials manager for coordinating re-auth + manager *Manager +} + +// OnNext is called when new credentials are received from the StreamingCredentialsProvider. +// +// This method marks the connection for asynchronous re-authentication. The actual +// re-authentication happens in the background when the connection is returned to the +// pool and is in an idle state. 
+// +// Asynchronous re-auth is used to: +// - Avoid blocking the credentials provider's notification goroutine +// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes) +// - Ensure re-auth happens when the connection is safe to use (not processing commands) +// +// The reAuthFn callback receives: +// - nil if the connection was successfully acquired for re-auth +// - error if acquisition timed out or failed +// +// Thread-safe: Called by the credentials provider's notification goroutine. +func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) { + if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil { + return + } + + // Always use async reauth to avoid complex pool semaphore issues + // The synchronous path can cause deadlocks in the pool's semaphore mechanism + // when called from the Subscribe goroutine, especially with small pool sizes. + // The connection pool hook will re-authenticate the connection when it is + // returned to the pool in a clean, idle state. + c.manager.MarkForReAuth(c.conn, func(err error) { + // err is from connection acquisition (timeout, etc.) + if err != nil { + // Log the error + c.OnError(err) + return + } + // err is from reauth command execution + err = c.reAuth(c.conn, credentials) + if err != nil { + // Log the error + c.OnError(err) + return + } + }) +} + +// OnError is called when an error occurs during credential streaming or re-authentication. +// +// This method can be called from: +// - The StreamingCredentialsProvider when there's an error in the credentials stream +// - The re-auth process when connection acquisition times out +// - The re-auth process when the AUTH command fails +// +// The error is delegated to the onErr callback provided during listener creation. +// +// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker). +func (c *ConnReAuthCredentialsListener) OnError(err error) { + if c.onErr == nil { + return + } + + c.onErr(c.conn, err) +} + +// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. +var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) diff --git a/internal/auth/streaming/cred_listeners.go b/internal/auth/streaming/cred_listeners.go new file mode 100644 index 000000000..66e6eafdc --- /dev/null +++ b/internal/auth/streaming/cred_listeners.go @@ -0,0 +1,77 @@ +package streaming + +import ( + "sync" + + "github.com/redis/go-redis/v9/auth" +) + +// CredentialsListeners is a thread-safe collection of credentials listeners +// indexed by connection ID. +// +// This collection is used by the Manager to maintain a registry of listeners +// for each connection in the pool. Listeners are reused when connections are +// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions +// to the StreamingCredentialsProvider. +// +// The collection supports concurrent access from multiple goroutines during +// connection initialization, credential updates, and connection removal. +type CredentialsListeners struct { + // listeners maps connection ID to credentials listener + listeners map[uint64]auth.CredentialsListener + + // lock protects concurrent access to the listeners map + lock sync.RWMutex +} + +// NewCredentialsListeners creates a new thread-safe credentials listeners collection. 
+func NewCredentialsListeners() *CredentialsListeners { + return &CredentialsListeners{ + listeners: make(map[uint64]auth.CredentialsListener), + } +} + +// Add adds or updates a credentials listener for a connection. +// +// If a listener already exists for the connection ID, it is replaced. +// This is safe because the old listener should have been unsubscribed +// before the connection was reinitialized. +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) { + c.lock.Lock() + defer c.lock.Unlock() + if c.listeners == nil { + c.listeners = make(map[uint64]auth.CredentialsListener) + } + c.listeners[connID] = listener +} + +// Get retrieves the credentials listener for a connection. +// +// Returns: +// - listener: The credentials listener for the connection, or nil if not found +// - ok: true if a listener exists for the connection ID, false otherwise +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + if len(c.listeners) == 0 { + return nil, false + } + listener, ok := c.listeners[connID] + return listener, ok +} + +// Remove removes the credentials listener for a connection. +// +// This is called when a connection is removed from the pool to prevent +// memory leaks. If no listener exists for the connection ID, this is a no-op. +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Remove(connID uint64) { + c.lock.Lock() + defer c.lock.Unlock() + delete(c.listeners, connID) +} diff --git a/internal/auth/streaming/manager.go b/internal/auth/streaming/manager.go new file mode 100644 index 000000000..f785927ee --- /dev/null +++ b/internal/auth/streaming/manager.go @@ -0,0 +1,137 @@ +package streaming + +import ( + "errors" + "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// Manager coordinates streaming credentials and re-authentication for a connection pool. +// +// The manager is responsible for: +// - Creating and managing per-connection credentials listeners +// - Providing the pool hook for re-authentication +// - Coordinating between credentials updates and pool operations +// +// When credentials change via a StreamingCredentialsProvider: +// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update +// 2. It calls MarkForReAuth on the manager +// 3. The manager delegates to the pool hook +// 4. The pool hook schedules background re-authentication +// +// The manager maintains a registry of credentials listeners indexed by connection ID, +// allowing listener reuse when connections are reinitialized (e.g., after handoff). +type Manager struct { + // credentialsListeners maps connection ID to credentials listener + credentialsListeners *CredentialsListeners + + // pool is the connection pool being managed + pool pool.Pooler + + // poolHookRef is the re-authentication pool hook + poolHookRef *ReAuthPoolHook +} + +// NewManager creates a new streaming credentials manager. +// +// Parameters: +// - pl: The connection pool to manage +// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication +// +// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that +// re-auth operations don't exhaust the connection pool. 
+func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager { + m := &Manager{ + pool: pl, + poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout), + credentialsListeners: NewCredentialsListeners(), + } + m.poolHookRef.manager = m + return m +} + +// PoolHook returns the pool hook for re-authentication. +// +// This hook should be registered with the connection pool to enable +// automatic re-authentication when credentials change. +func (m *Manager) PoolHook() pool.PoolHook { + return m.poolHookRef +} + +// Listener returns or creates a credentials listener for a connection. +// +// This method is called during connection initialization to set up the +// credentials listener. If a listener already exists for the connection ID +// (e.g., after a handoff), it is reused. +// +// Parameters: +// - poolCn: The connection to create/get a listener for +// - reAuth: Function to re-authenticate the connection with new credentials +// - onErr: Function to call when re-authentication fails +// +// Returns: +// - auth.CredentialsListener: The listener to subscribe to the credentials provider +// - error: Non-nil if poolCn is nil +// +// Note: The reAuth and onErr callbacks are captured once when the listener is +// created and reused for the connection's lifetime. They should not change. +// +// Thread-safe: Can be called concurrently during connection initialization. +func (m *Manager) Listener( + poolCn *pool.Conn, + reAuth func(*pool.Conn, auth.Credentials) error, + onErr func(*pool.Conn, error), +) (auth.CredentialsListener, error) { + if poolCn == nil { + return nil, errors.New("poolCn cannot be nil") + } + connID := poolCn.GetID() + // if we reconnect the underlying network connection, the streaming credentials listener will continue to work + // so we can get the old listener from the cache and use it. + // subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op + listener, ok := m.credentialsListeners.Get(connID) + if !ok || listener == nil { + // Create new listener for this connection + // Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime + newCredListener := &ConnReAuthCredentialsListener{ + conn: poolCn, + reAuth: reAuth, + onErr: onErr, + manager: m, + } + + m.credentialsListeners.Add(connID, newCredListener) + listener = newCredListener + } + return listener, nil +} + +// MarkForReAuth marks a connection for re-authentication. +// +// This method is called by the credentials listener when new credentials are +// received. It delegates to the pool hook to schedule background re-authentication. +// +// Parameters: +// - poolCn: The connection to re-authenticate +// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails +// +// Thread-safe: Called by credentials listeners when credentials change. +func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { + connID := poolCn.GetID() + m.poolHookRef.MarkForReAuth(connID, reAuthFn) +} + +// RemoveListener removes the credentials listener for a connection. +// +// This method is called by the pool hook's OnRemove to clean up listeners +// when connections are removed from the pool. +// +// Parameters: +// - connID: The connection ID whose listener should be removed +// +// Thread-safe: Called during connection removal. 
+func (m *Manager) RemoveListener(connID uint64) { + m.credentialsListeners.Remove(connID) +} diff --git a/internal/auth/streaming/manager_test.go b/internal/auth/streaming/manager_test.go new file mode 100644 index 000000000..e4ff813ed --- /dev/null +++ b/internal/auth/streaming/manager_test.go @@ -0,0 +1,101 @@ +package streaming + +import ( + "context" + "testing" + "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// Test that Listener returns the newly created listener, not nil +func TestManager_Listener_ReturnsNewListener(t *testing.T) { + // Create a mock pool + mockPool := &mockPooler{} + + // Create manager + manager := NewManager(mockPool, time.Second) + + // Create a mock connection + conn := &pool.Conn{} + + // Mock functions + reAuth := func(cn *pool.Conn, creds auth.Credentials) error { + return nil + } + + onErr := func(cn *pool.Conn, err error) { + } + + // Get listener - this should create a new one + listener, err := manager.Listener(conn, reAuth, onErr) + + // Verify no error + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Verify listener is not nil (this was the bug!) + if listener == nil { + t.Fatal("Expected listener to be non-nil, but got nil") + } + + // Verify it's the correct type + if _, ok := listener.(*ConnReAuthCredentialsListener); !ok { + t.Fatalf("Expected listener to be *ConnReAuthCredentialsListener, got %T", listener) + } + + // Get the same listener again - should return the existing one + listener2, err := manager.Listener(conn, reAuth, onErr) + if err != nil { + t.Fatalf("Expected no error on second call, got: %v", err) + } + + if listener2 == nil { + t.Fatal("Expected listener2 to be non-nil") + } + + // Should be the same instance + if listener != listener2 { + t.Error("Expected to get the same listener instance on second call") + } +} + +// Test that Listener returns error when conn is nil +func TestManager_Listener_NilConn(t *testing.T) { + mockPool := &mockPooler{} + manager := NewManager(mockPool, time.Second) + + listener, err := manager.Listener(nil, nil, nil) + + if err == nil { + t.Fatal("Expected error when conn is nil, got nil") + } + + if listener != nil { + t.Error("Expected listener to be nil when error occurs") + } + + expectedErr := "poolCn cannot be nil" + if err.Error() != expectedErr { + t.Errorf("Expected error message %q, got %q", expectedErr, err.Error()) + } +} + +// Mock pooler for testing +type mockPooler struct{} + +func (m *mockPooler) NewConn(ctx context.Context) (*pool.Conn, error) { return nil, nil } +func (m *mockPooler) CloseConn(*pool.Conn) error { return nil } +func (m *mockPooler) Get(ctx context.Context) (*pool.Conn, error) { return nil, nil } +func (m *mockPooler) Put(ctx context.Context, conn *pool.Conn) {} +func (m *mockPooler) Remove(ctx context.Context, conn *pool.Conn, reason error) {} +func (m *mockPooler) Len() int { return 0 } +func (m *mockPooler) IdleLen() int { return 0 } +func (m *mockPooler) Stats() *pool.Stats { return &pool.Stats{} } +func (m *mockPooler) Size() int { return 10 } +func (m *mockPooler) AddPoolHook(hook pool.PoolHook) {} +func (m *mockPooler) RemovePoolHook(hook pool.PoolHook) {} +func (m *mockPooler) Close() error { return nil } + diff --git a/internal/auth/streaming/pool_hook.go b/internal/auth/streaming/pool_hook.go new file mode 100644 index 000000000..c135e169c --- /dev/null +++ b/internal/auth/streaming/pool_hook.go @@ -0,0 +1,259 @@ +package streaming + +import ( + "context" + "sync" + "time" + + 
"github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ReAuthPoolHook is a pool hook that manages background re-authentication of connections +// when credentials change via a streaming credentials provider. +// +// The hook uses a semaphore-based worker pool to limit concurrent re-authentication +// operations and prevent pool exhaustion. When credentials change, connections are +// marked for re-authentication and processed asynchronously in the background. +// +// The re-authentication process: +// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth +// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth) +// 3. A worker goroutine acquires the connection (waits until it's not in use) +// 4. Executes the re-auth function while holding the connection +// 5. Releases the connection back to the pool +// +// The hook ensures that: +// - Only one re-auth operation runs per connection at a time +// - Connections are not used for commands during re-authentication +// - Re-auth operations timeout if they can't acquire the connection +// - Resources are properly cleaned up on connection removal +type ReAuthPoolHook struct { + // shouldReAuth maps connection ID to re-auth function + // Connections in this map need re-authentication but haven't been scheduled yet + shouldReAuth map[uint64]func(error) + shouldReAuthLock sync.RWMutex + + // workers is a semaphore channel limiting concurrent re-auth operations + // Initialized with poolSize tokens to prevent pool exhaustion + workers chan struct{} + + // reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth + reAuthTimeout time.Duration + + // scheduledReAuth maps connection ID to scheduled status + // Connections in this map have a background worker attempting re-authentication + scheduledReAuth map[uint64]bool + scheduledLock sync.RWMutex + + // manager is a back-reference for cleanup operations + manager *Manager +} + +// NewReAuthPoolHook creates a new re-authentication pool hook. +// +// Parameters: +// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size) +// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication +// +// The poolSize parameter is used to initialize the worker semaphore, ensuring that +// re-auth operations don't exhaust the connection pool. +func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { + workers := make(chan struct{}, poolSize) + // Initialize the workers channel with tokens (semaphore pattern) + for i := 0; i < poolSize; i++ { + workers <- struct{}{} + } + + return &ReAuthPoolHook{ + shouldReAuth: make(map[uint64]func(error)), + scheduledReAuth: make(map[uint64]bool), + workers: workers, + reAuthTimeout: reAuthTimeout, + } +} + +// MarkForReAuth marks a connection for re-authentication. +// +// This method is called when credentials change and a connection needs to be +// re-authenticated. The actual re-authentication happens asynchronously when +// the connection is returned to the pool (in OnPut). +// +// Parameters: +// - connID: The connection ID to mark for re-authentication +// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails +// +// Thread-safe: Can be called concurrently from multiple goroutines. 
+func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) { + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() + r.shouldReAuth[connID] = reAuthFn +} + +// OnGet is called when a connection is retrieved from the pool. +// +// This hook checks if the connection needs re-authentication or has a scheduled +// re-auth operation. If so, it rejects the connection (returns accept=false), +// causing the pool to try another connection. +// +// Returns: +// - accept: false if connection needs re-auth, true otherwise +// - err: always nil (errors are not used in this hook) +// +// Thread-safe: Called concurrently by multiple goroutines getting connections. +func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { + connID := conn.GetID() + r.shouldReAuthLock.RLock() + _, shouldReAuth := r.shouldReAuth[connID] + r.shouldReAuthLock.RUnlock() + // This connection was marked for reauth while in the pool, + // reject the connection + if shouldReAuth { + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil + } + r.scheduledLock.RLock() + _, hasScheduled := r.scheduledReAuth[connID] + r.scheduledLock.RUnlock() + // has scheduled reauth, reject the connection + if hasScheduled { + // simply reject the connection, it currently has a reauth scheduled + // and the worker is waiting for slot to execute the reauth + return false, nil + } + return true, nil +} + +// OnPut is called when a connection is returned to the pool. +// +// This hook checks if the connection needs re-authentication. If so, it schedules +// a background goroutine to perform the re-auth asynchronously. The goroutine: +// 1. Waits for a worker slot (semaphore) +// 2. Acquires the connection (waits until not in use) +// 3. Executes the re-auth function +// 4. Releases the connection and worker slot +// +// The connection is always pooled (not removed) since re-auth happens in background. +// +// Returns: +// - shouldPool: always true (connection stays in pool during background re-auth) +// - shouldRemove: always false +// - err: always nil +// +// Thread-safe: Called concurrently by multiple goroutines returning connections. 
+func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) { + if conn == nil { + // noop + return true, false, nil + } + connID := conn.GetID() + // Check if reauth is needed and get the function with proper locking + r.shouldReAuthLock.RLock() + reAuthFn, ok := r.shouldReAuth[connID] + r.shouldReAuthLock.RUnlock() + + if ok { + // Acquire both locks to atomically move from shouldReAuth to scheduledReAuth + // This prevents race conditions where OnGet might miss the transition + r.shouldReAuthLock.Lock() + r.scheduledLock.Lock() + r.scheduledReAuth[connID] = true + delete(r.shouldReAuth, connID) + r.scheduledLock.Unlock() + r.shouldReAuthLock.Unlock() + go func() { + <-r.workers + // safety first + if conn == nil || (conn != nil && conn.IsClosed()) { + r.workers <- struct{}{} + return + } + defer func() { + if rec := recover(); rec != nil { + // once again - safety first + internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec) + } + r.scheduledLock.Lock() + delete(r.scheduledReAuth, connID) + r.scheduledLock.Unlock() + r.workers <- struct{}{} + }() + + var err error + timeout := time.After(r.reAuthTimeout) + + // Try to acquire the connection + // We need to ensure the connection is both Usable and not Used + // to prevent data races with concurrent operations + const baseDelay = 10 * time.Microsecond + acquired := false + attempt := 0 + for !acquired { + select { + case <-timeout: + // Timeout occurred, cannot acquire connection + err = pool.ErrConnUnusableTimeout + reAuthFn(err) + return + default: + // Try to acquire: set Usable=false, then check Used + if conn.CompareAndSwapUsable(true, false) { + if !conn.IsUsed() { + acquired = true + } else { + // Release Usable and retry with exponential backoff + // todo(ndyakov): think of a better way to do this without the need + // to release the connection, but just wait till it is not used + conn.SetUsable(true) + } + } + if !acquired { + // Exponential backoff: 10, 20, 40, 80... 
up to 5120 microseconds + delay := baseDelay * time.Duration(1< 0 && attempt < maxRetries-1 { + delay := baseDelay * time.Duration(1<= getAttempts { internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) @@ -454,17 +470,19 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) + if err != nil { internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) - // Failed to process connection, discard it _ = p.CloseConn(cn) continue } + if !acceptConn { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.Put(ctx, cn) + cn = nil + continue + } } atomic.AddUint32(&p.stats.Hits, 1) @@ -480,14 +498,13 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { } // Process connection using the hooks system - p.hookManagerMu.RLock() - hookManager := p.hookManager - p.hookManagerMu.RUnlock() - if hookManager != nil { - if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { + acceptConn, err := hookManager.ProcessOnGet(ctx, newcn, true) + // both errors and accept=false mean a hook rejected the connection + // this should not happen with a new connection, but we handle it gracefully + if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err) + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } @@ -567,9 +584,12 @@ func (p *ConnPool) popIdle() (*Conn, error) { } attempts++ - if cn.IsUsable() { - p.idleConnsLen.Add(-1) - break + if cn.CompareAndSwapUsed(false, true) { + if cn.IsUsable() { + p.idleConnsLen.Add(-1) + break + } + cn.SetUsed(false) } // Connection is not usable, put it back in the pool @@ -664,6 +684,11 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { shouldCloseConn = true } + // if the connection is not going to be closed, mark it as not used + if !shouldCloseConn { + cn.SetUsed(false) + } + p.freeTurn() if shouldCloseConn { @@ -671,7 +696,15 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { } } -func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + hookManager.ProcessOnRemove(ctx, cn, reason) + } + p.removeConnWithLock(cn) p.freeTurn() @@ -733,6 +766,14 @@ func (p *ConnPool) IdleLen() int { return int(n) } +// Size returns the maximum pool size (capacity). +// +// This is used by the streaming credentials manager to size the re-auth worker pool, +// ensuring that re-auth operations don't exhaust the connection pool. 
+func (p *ConnPool) Size() int { + return int(p.cfg.PoolSize) +} + func (p *ConnPool) Stats() *Stats { return &Stats{ Hits: atomic.LoadUint32(&p.stats.Hits), diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 136d6f2dd..712d482d8 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -2,8 +2,12 @@ package pool import ( "context" + "time" ) +// SingleConnPool is a pool that always returns the same connection. +// Note: This pool is not thread-safe. +// It is intended to be used by clients that need a single connection. type SingleConnPool struct { pool Pooler cn *Conn @@ -12,6 +16,12 @@ type SingleConnPool struct { var _ Pooler = (*SingleConnPool)(nil) +// NewSingleConnPool creates a new single connection pool. +// The pool will always return the same connection. +// The pool will not: +// - Close the connection +// - Reconnect the connection +// - Track the connection in any way func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { return &SingleConnPool{ pool: pool, @@ -27,16 +37,30 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error { return p.pool.CloseConn(cn) } -func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { +func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) { if p.stickyErr != nil { return nil, p.stickyErr } + if p.cn == nil { + return nil, ErrClosed + } + p.cn.SetUsed(true) + p.cn.SetUsedAt(time.Now()) return p.cn, nil } -func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} +func (p *SingleConnPool) Put(_ context.Context, cn *Conn) { + if p.cn == nil { + return + } + if p.cn != cn { + return + } + p.cn.SetUsed(false) +} -func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { +func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) { + cn.SetUsed(false) p.cn = nil p.stickyErr = reason } @@ -55,10 +79,13 @@ func (p *SingleConnPool) IdleLen() int { return 0 } +// Size returns the maximum pool size, which is always 1 for SingleConnPool. +func (p *SingleConnPool) Size() int { return 1 } + func (p *SingleConnPool) Stats() *Stats { return &Stats{} } -func (p *SingleConnPool) AddPoolHook(hook PoolHook) {} +func (p *SingleConnPool) AddPoolHook(_ PoolHook) {} -func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {} +func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index dc4266a4f..22e5a941b 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -196,6 +196,9 @@ func (p *StickyConnPool) IdleLen() int { return len(p.ch) } +// Size returns the maximum pool size, which is always 1 for StickyConnPool. 
+func (p *StickyConnPool) Size() int { return 1 } + func (p *StickyConnPool) Stats() *Stats { return &Stats{} } diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index ef1ed5f9b..6aa6dc091 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -497,9 +497,14 @@ func TestDialerRetryConfiguration(t *testing.T) { } // Should have attempted 5 times (default DialerRetries = 5) + // Note: There may be one additional attempt from tryDial() goroutine + // which is launched when dialErrorsNum reaches PoolSize finalAttempts := atomic.LoadInt64(&attempts) - if finalAttempts != 5 { - t.Errorf("Expected 5 dial attempts (default), got %d", finalAttempts) + if finalAttempts < 5 { + t.Errorf("Expected at least 5 dial attempts (default), got %d", finalAttempts) + } + if finalAttempts > 6 { + t.Errorf("Expected around 5 dial attempts, got %d (too many)", finalAttempts) } }) } diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index 73ee4b3ec..ed87d1bbc 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -24,6 +24,8 @@ type PubSubPool struct { stats PubSubStats } +// PubSubPool implements a pool for PubSub connections. +// It intentionally does not implement the Pooler interface func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { return &PubSubPool{ opt: opt, diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 61dc1e171..22df2c800 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -378,8 +378,12 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c } // performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration) -func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, conn *pool.Conn, newEndpoint string, connID uint64) (shouldRetry bool, err error) { - +func (hwm *handoffWorkerManager) performHandoffInternal( + ctx context.Context, + conn *pool.Conn, + newEndpoint string, + connID uint64, +) (shouldRetry bool, err error) { retries := conn.IncrementAndGetHandoffRetries(1) internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) maxRetries := 3 // Default fallback @@ -438,9 +442,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal(ctx context.Context, con } }() + // Clear handoff state will: + // - set the connection as usable again + // - clear the handoff state (shouldHandoff, endpoint, seqID) + // - reset the handoff retries to 0 conn.ClearHandoffState() internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + // successfully completed the handoff, no retry needed and no error return false, nil } @@ -472,7 +481,10 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) } } else { - conn.Close() + err := conn.Close() // Close the connection if no pool provided + if err != nil { + internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err) + } if internal.LogLevel.WarnOrAbove() { internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) } diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index 695c3a648..9fd24b4a7 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -116,22 +116,22 @@ func (ph *PoolHook) 
ResetCircuitBreakers() { } // OnGet is called when a connection is retrieved from the pool -func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error { +func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is // in a handoff state at the moment. // Check if connection is usable (not in a handoff state) // Should not happen since the pool will not return a connection that is not usable. if !conn.IsUsable() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } // Check if connection is marked for handoff, which means it will be queued for handoff on put. if conn.ShouldHandoff() { - return ErrConnectionMarkedForHandoff + return false, ErrConnectionMarkedForHandoff } - return nil + return true, nil } // OnPut is called when a connection is returned to the pool @@ -174,6 +174,10 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool return true, false, nil } +func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { + // Not used +} + // Shutdown gracefully shuts down the processor, waiting for workers to complete func (ph *PoolHook) Shutdown(ctx context.Context) error { return ph.workerManager.shutdownWorkers(ctx) diff --git a/maintnotifications/pool_hook_test.go b/maintnotifications/pool_hook_test.go index c689179d7..51e73c3ec 100644 --- a/maintnotifications/pool_hook_test.go +++ b/maintnotifications/pool_hook_test.go @@ -92,6 +92,10 @@ func (mp *mockPool) Stats() *pool.Stats { return &pool.Stats{} } +func (mp *mockPool) Size() int { + return 0 +} + func (mp *mockPool) AddPoolHook(hook pool.PoolHook) { // Mock implementation - do nothing } @@ -356,10 +360,13 @@ func TestConnectionHook(t *testing.T) { conn := createMockPoolConnection() ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should not error for normal connection: %v", err) } + if !acceptCon { + t.Error("Connection should be accepted for normal connection") + } }) t.Run("OnGetWithPendingHandoff", func(t *testing.T) { @@ -381,10 +388,13 @@ func TestConnectionHook(t *testing.T) { conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptCon { + t.Error("Connection should not be accepted when marked for handoff") + } // Clean up processor.GetPendingMap().Delete(conn) @@ -412,10 +422,13 @@ func TestConnectionHook(t *testing.T) { // Test OnGet with pending handoff ctx := context.Background() - err := processor.OnGet(ctx, conn, false) + acceptCon, err := processor.OnGet(ctx, conn, false) if err != ErrConnectionMarkedForHandoff { t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") } + if acceptCon { + t.Error("Should not accept connection with pending handoff") + } // Test removing from pending map and clearing handoff state processor.GetPendingMap().Delete(conn) @@ -428,10 +441,13 @@ func TestConnectionHook(t *testing.T) { conn.SetUsable(true) // Make connection usable again // Test OnGet without pending handoff - err = processor.OnGet(ctx, conn, false) + acceptCon, err = processor.OnGet(ctx, conn, false) if err != nil 
{ t.Errorf("Should not return error for non-pending connection: %v", err) } + if !acceptCon { + t.Error("Should accept connection without pending handoff") + } }) t.Run("EventDrivenQueueOptimization", func(t *testing.T) { @@ -624,11 +640,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed for usable connection - err := processor.OnGet(ctx, conn, false) + acceptConn, err := processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed for usable connection: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted when usable") + } + // Mark connection for handoff if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { t.Fatalf("Failed to mark connection for handoff: %v", err) @@ -648,13 +668,17 @@ func TestConnectionHook(t *testing.T) { } // OnGet should fail for connection marked for handoff - err = processor.OnGet(ctx, conn, false) + acceptConn, err = processor.OnGet(ctx, conn, false) if err == nil { t.Error("OnGet should fail for connection marked for handoff") } + if err != ErrConnectionMarkedForHandoff { t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) } + if acceptConn { + t.Error("Connection should not be accepted when marked for handoff") + } // Process the connection to trigger handoff shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) @@ -674,11 +698,15 @@ func TestConnectionHook(t *testing.T) { } // OnGet should succeed again - err = processor.OnGet(ctx, conn, false) + acceptConn, err = processor.OnGet(ctx, conn, false) if err != nil { t.Errorf("OnGet should succeed after handoff completion: %v", err) } + if !acceptConn { + t.Error("Connection should be accepted after handoff completion") + } + t.Logf("Usable flag behavior test completed successfully") }) diff --git a/pubsub.go b/pubsub.go index 5e02b0bd2..959a5c45b 100644 --- a/pubsub.go +++ b/pubsub.go @@ -465,7 +465,6 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int } // Don't hold the lock to allow subscriptions and pings. 
- cn, err := c.connWithLock(ctx) if err != nil { return nil, err diff --git a/redis.go b/redis.go index b308263e2..dcd7b59a7 100644 --- a/redis.go +++ b/redis.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/auth/streaming" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" @@ -224,6 +225,9 @@ type baseClient struct { // Maintenance notifications manager maintNotificationsManager *maintnotifications.Manager maintNotificationsManagerLock sync.RWMutex + + // streamingCredentialsManager is used to manage streaming credentials + streamingCredentialsManager *streaming.Manager } func (c *baseClient) clone() *baseClient { @@ -232,11 +236,12 @@ func (c *baseClient) clone() *baseClient { c.maintNotificationsManagerLock.RUnlock() clone := &baseClient{ - opt: c.opt, - connPool: c.connPool, - onClose: c.onClose, - pushProcessor: c.pushProcessor, - maintNotificationsManager: maintNotificationsManager, + opt: c.opt, + connPool: c.connPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + maintNotificationsManager: maintNotificationsManager, + streamingCredentialsManager: c.streamingCredentialsManager, } return clone } @@ -296,32 +301,30 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return cn, nil } -func (c *baseClient) newReAuthCredentialsListener(poolCn *pool.Conn) auth.CredentialsListener { - return auth.NewReAuthCredentialsListener( - c.reAuthConnection(poolCn), - c.onAuthenticationErr(poolCn), - ) -} - -func (c *baseClient) reAuthConnection(poolCn *pool.Conn) func(credentials auth.Credentials) error { - return func(credentials auth.Credentials) error { +func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error { + return func(poolCn *pool.Conn, credentials auth.Credentials) error { var err error username, password := credentials.BasicAuth() + + // Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter ctx := context.Background() + connPool := pool.NewSingleConnPool(c.connPool, poolCn) - // hooksMixin are intentionally empty here - cn := newConn(c.opt, connPool, nil) + + // Pass hooks so that reauth commands are recorded/traced + cn := newConn(c.opt, connPool, &c.hooksMixin) if username != "" { err = cn.AuthACL(ctx, username, password).Err() } else { err = cn.Auth(ctx, password).Err() } + return err } } -func (c *baseClient) onAuthenticationErr(poolCn *pool.Conn) func(err error) { - return func(err error) { +func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { + return func(poolCn *pool.Conn, err error) { if err != nil { if isBadConn(err, false, c.opt.Addr) { // Close the connection to force a reconnection. @@ -372,13 +375,24 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { username, password := "", "" if c.opt.StreamingCredentialsProvider != nil { + credListener, err := c.streamingCredentialsManager.Listener( + cn, + c.reAuthConnection(), + c.onAuthenticationErr(), + ) + if err != nil { + return fmt.Errorf("failed to create credentials listener: %w", err) + } + credentials, unsubscribeFromCredentialsProvider, err := c.opt.StreamingCredentialsProvider. 
- Subscribe(c.newReAuthCredentialsListener(cn)) + Subscribe(credListener) if err != nil { return fmt.Errorf("failed to subscribe to streaming credentials: %w", err) } + c.onClose = c.wrappedOnClose(unsubscribeFromCredentialsProvider) cn.SetOnClose(unsubscribeFromCredentialsProvider) + username, password = credentials.BasicAuth() } else if c.opt.CredentialsProviderContext != nil { username, password, err = c.opt.CredentialsProviderContext(ctx) @@ -496,7 +510,10 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + // mark the connection as usable and inited + // once returned to the pool as idle, this connection can be used by other clients cn.SetUsable(true) + cn.SetUsed(false) cn.Inited.Store(true) // Set the connection initialization function for potential reconnections @@ -952,6 +969,11 @@ func NewClient(opt *Options) *Client { panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) } + if opt.StreamingCredentialsProvider != nil { + c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout) + c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook()) + } + // Initialize maintnotifications first if enabled and protocol is RESP3 if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { err := c.enableMaintNotificationsUpgrades() diff --git a/redis_test.go b/redis_test.go index 27b69ed14..0906d420b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -854,24 +854,34 @@ var _ = Describe("Credentials Provider Priority", func() { credentials: initialCreds, updates: updatesChan, }, + PoolSize: 1, // Force single connection to ensure reauth is tested } client = redis.NewClient(opt) client.AddHook(recorder.Hook()) // wrongpass Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) + time.Sleep(10 * time.Millisecond) Expect(recorder.Contains("AUTH initial_user")).To(BeTrue()) // Update credentials opt.StreamingCredentialsProvider.(*mockStreamingProvider).updates <- updatedCreds - // wrongpass - Expect(client.Ping(context.Background()).Err()).To(HaveOccurred()) - Expect(recorder.Contains("AUTH updated_user")).To(BeTrue()) + + // Wait for reauth to complete and verify updated credentials are used + // We need to keep trying Ping until we see the updated AUTH command + // because the reauth happens asynchronously + Eventually(func() bool { + // wrongpass + _ = client.Ping(context.Background()).Err() + return recorder.Contains("AUTH updated_user") + }, "1s", "50ms").Should(BeTrue()) + close(updatesChan) }) }) type mockStreamingProvider struct { + mu sync.RWMutex credentials auth.Credentials err error updates chan auth.Credentials @@ -882,21 +892,50 @@ func (m *mockStreamingProvider) Subscribe(listener auth.CredentialsListener) (au return nil, nil, m.err } + if listener == nil { + return nil, nil, errors.New("listener cannot be nil") + } + + // Create a done channel to stop the goroutine + done := make(chan struct{}) + // Start goroutine to handle updates go func() { - for creds := range m.updates { - m.credentials = creds - listener.OnNext(creds) + defer func() { + if r := recover(); r != nil { + // this is just a mock: + // allow panics to be caught without crashing + } + }() + + for { + select { + case <-done: + return + case creds, ok := <-m.updates: + if !ok { + return + } + m.mu.Lock() + m.credentials = creds + m.mu.Unlock() + listener.OnNext(creds) + } } }() - return m.credentials, func() (err error) { + m.mu.RLock() + currentCreds := 
m.credentials + m.mu.RUnlock() + + return currentCreds, func() (err error) { defer func() { if r := recover(); r != nil { // this is just a mock: // allow multiple closes from multiple listeners } }() + close(done) return }, nil } diff --git a/sentinel_test.go b/sentinel_test.go index bfeb28161..f332822f5 100644 --- a/sentinel_test.go +++ b/sentinel_test.go @@ -410,7 +410,9 @@ var _ = Describe("SentinelAclAuth", func() { }) }) -func TestParseFailoverURL(t *testing.T) { +// renaming from TestParseFailoverURL to TestParseSentinelURL +// to be easier to find Failed tests in the test output +func TestParseSentinelURL(t *testing.T) { cases := []struct { url string o *redis.FailoverOptions
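

For readers unfamiliar with the pattern, the worker scheme described in the pool_hook.go comments above (a buffered channel pre-filled with poolSize tokens, plus a timeout-bounded acquire loop with exponential backoff) can be sketched in isolation using only the standard library. This is an illustration of the technique, not code from this PR: the atomic.Bool stands in for pool.Conn's Usable/Used flags, and the 10µs base / 5.12ms cap simply mirror the numbers mentioned in the comment.

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// acquire claims the shared resource by flipping an atomic flag, backing off
// exponentially (10µs, 20µs, ... capped at 5.12ms) and giving up once the
// timeout elapses — analogous to the Usable/Used loop in ReAuthPoolHook.OnPut.
func acquire(usable *atomic.Bool, timeout time.Duration) bool {
	const baseDelay = 10 * time.Microsecond
	deadline := time.After(timeout)
	for attempt := 0; ; attempt++ {
		select {
		case <-deadline:
			return false // the real hook reports pool.ErrConnUnusableTimeout here
		default:
			if usable.CompareAndSwap(true, false) {
				return true
			}
			shift := attempt
			if shift > 9 { // 10µs << 9 = 5.12ms cap
				shift = 9
			}
			time.Sleep(baseDelay * time.Duration(1<<shift))
		}
	}
}

func main() {
	const poolSize = 3

	// Semaphore: a channel pre-filled with poolSize tokens, as in NewReAuthPoolHook.
	workers := make(chan struct{}, poolSize)
	for i := 0; i < poolSize; i++ {
		workers <- struct{}{}
	}

	// One shared "connection" that every worker wants to re-authenticate.
	var usable atomic.Bool
	usable.Store(true)

	done := make(chan struct{}, 5)
	for i := 0; i < 5; i++ {
		go func(id int) {
			<-workers                                // take a worker slot
			defer func() { workers <- struct{}{} }() // always hand it back

			if acquire(&usable, 50*time.Millisecond) {
				fmt.Printf("worker %d re-authenticated\n", id)
				usable.Store(true) // release the connection for the next worker
			} else {
				fmt.Printf("worker %d timed out\n", id)
			}
			done <- struct{}{}
		}(i)
	}
	for i := 0; i < 5; i++ {
		<-done
	}
}
```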
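
And a minimal sketch of the client-side wiring that ends up on this code path. Only Options.StreamingCredentialsProvider and the reuse of PoolTimeout as the re-auth acquire timeout are taken from the diff above; the staticProvider type, the credential values, and the auth.NewBasicCredentials helper are assumptions for illustration, and the Subscribe signature (including auth.UnsubscribeFunc) is written to match the provider interface as exercised by mockStreamingProvider in redis_test.go. A real provider (for example an Entra ID token source) would look different.

```go
package main

import (
	"context"
	"time"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/auth"
)

// staticProvider serves one fixed credential set and never rotates it.
// A real provider would keep the listener and call listener.OnNext whenever
// the credentials rotate (and OnError on stream failures), which is what
// triggers the asynchronous re-auth added in this PR.
type staticProvider struct {
	creds auth.Credentials
}

func (p *staticProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
	return p.creds, func() error { return nil }, nil
}

func main() {
	provider := &staticProvider{
		// auth.NewBasicCredentials is assumed here; any auth.Credentials works.
		creds: auth.NewBasicCredentials("app-user", "app-password"),
	}

	rdb := redis.NewClient(&redis.Options{
		Addr:                         "localhost:6379",
		Protocol:                     3,
		StreamingCredentialsProvider: provider,
		// NewClient reuses PoolTimeout as the re-auth acquire timeout
		// (streaming.NewManager(c.connPool, c.opt.PoolTimeout) above).
		PoolTimeout: 5 * time.Second,
	})
	defer rdb.Close()

	// Commands keep flowing while re-auth happens in the background once
	// credentials rotate and connections return to the pool idle.
	_ = rdb.Ping(context.Background()).Err()
}
```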