diff --git a/config.yaml b/config.yaml index 283d2890..5bfe96e9 100644 --- a/config.yaml +++ b/config.yaml @@ -256,16 +256,36 @@ synchronization: ## Agent publishes to channels: "optimizely-sync-{sdk_key}" ## For external Redis clients: Subscribe "optimizely-sync-{sdk_key}" or PSubscribe "optimizely-sync-*" ## Note: Channel configuration parsing is a known bug - planned for future release + + ## Redis Streams configuration (when using Redis Streams for notifications) + ## batch_size: number of messages to batch before sending (default: 10) + # batch_size: 10 + ## flush_interval: maximum time to wait before sending a partial batch (default: 5s) + # flush_interval: 5s + ## max_retries: maximum number of retry attempts for failed operations (default: 3) + # max_retries: 3 + ## retry_delay: initial delay between retry attempts (default: 100ms) + # retry_delay: 100ms + ## max_retry_delay: maximum delay between retry attempts with exponential backoff (default: 5s) + # max_retry_delay: 5s + ## connection_timeout: timeout for Redis connections (default: 10s) + # connection_timeout: 10s ## if notification synchronization is enabled, then the active notification event-stream API ## will get the notifications from available replicas notification: enable: false + ## Use "redis" for fire-and-forget pub/sub (existing behavior) + ## Use "redis-streams" for persistent message delivery with retries and acknowledgment default: "redis" + # default: "redis-streams" # Uncomment to enable Redis Streams ## if datafile synchronization is enabled, then for each webhook API call ## the datafile will be sent to all available replicas to achieve better eventual consistency datafile: enable: false + ## Use "redis" for fire-and-forget pub/sub (existing behavior) + ## Use "redis-streams" for persistent message delivery with retries and acknowledgment default: "redis" + # default: "redis-streams" # Uncomment to enable Redis Streams ## ## cmab: Contextual Multi-Armed Bandit configuration diff --git a/pkg/syncer/pubsub.go b/pkg/syncer/pubsub.go index 6436c03e..9e80293b 100644 --- a/pkg/syncer/pubsub.go +++ b/pkg/syncer/pubsub.go @@ -20,6 +20,7 @@ package syncer import ( "context" "errors" + "time" "github.com/optimizely/agent/config" "github.com/optimizely/agent/pkg/syncer/pubsub" @@ -28,8 +29,10 @@ import ( const ( // PubSubDefaultChan will be used as default pubsub channel name PubSubDefaultChan = "optimizely-sync" - // PubSubRedis is the name of pubsub type of Redis + // PubSubRedis is the name of pubsub type of Redis (fire-and-forget) PubSubRedis = "redis" + // PubSubRedisStreams is the name of pubsub type of Redis Streams (persistent) + PubSubRedisStreams = "redis-streams" ) type SycnFeatureFlag string @@ -48,12 +51,16 @@ func newPubSub(conf config.SyncConfig, featureFlag SycnFeatureFlag) (PubSub, err if featureFlag == SyncFeatureFlagNotificaiton { if conf.Notification.Default == PubSubRedis { return getPubSubRedis(conf) + } else if conf.Notification.Default == PubSubRedisStreams { + return getPubSubRedisStreams(conf) } else { return nil, errors.New("pubsub type not supported") } } else if featureFlag == SycnFeatureFlagDatafile { if conf.Datafile.Default == PubSubRedis { return getPubSubRedis(conf) + } else if conf.Datafile.Default == PubSubRedisStreams { + return getPubSubRedisStreams(conf) } else { return nil, errors.New("pubsub type not supported") } @@ -99,9 +106,92 @@ func getPubSubRedis(conf config.SyncConfig) (PubSub, error) { return nil, errors.New("pubsub redis database not valid, database must be int") } + 
// Return original Redis pub/sub implementation (fire-and-forget) return &pubsub.Redis{ Host: host, Password: password, Database: database, }, nil } + +func getPubSubRedisStreams(conf config.SyncConfig) (PubSub, error) { + pubsubConf, found := conf.Pubsub[PubSubRedis] + if !found { + return nil, errors.New("pubsub redis config not found") + } + + redisConf, ok := pubsubConf.(map[string]interface{}) + if !ok { + return nil, errors.New("pubsub redis config not valid") + } + + hostVal, found := redisConf["host"] + if !found { + return nil, errors.New("pubsub redis host not found") + } + host, ok := hostVal.(string) + if !ok { + return nil, errors.New("pubsub redis host not valid, host must be string") + } + + passwordVal, found := redisConf["password"] + if !found { + return nil, errors.New("pubsub redis password not found") + } + password, ok := passwordVal.(string) + if !ok { + return nil, errors.New("pubsub redis password not valid, password must be string") + } + + databaseVal, found := redisConf["database"] + if !found { + return nil, errors.New("pubsub redis database not found") + } + database, ok := databaseVal.(int) + if !ok { + return nil, errors.New("pubsub redis database not valid, database must be int") + } + + // Parse optional Redis Streams configuration parameters + batchSize := getIntFromConfig(redisConf, "batch_size", 10) + flushInterval := getDurationFromConfig(redisConf, "flush_interval", 5*time.Second) + maxRetries := getIntFromConfig(redisConf, "max_retries", 3) + retryDelay := getDurationFromConfig(redisConf, "retry_delay", 100*time.Millisecond) + maxRetryDelay := getDurationFromConfig(redisConf, "max_retry_delay", 5*time.Second) + connTimeout := getDurationFromConfig(redisConf, "connection_timeout", 10*time.Second) + + // Return Redis Streams implementation with configuration + return &pubsub.RedisStreams{ + Host: host, + Password: password, + Database: database, + BatchSize: batchSize, + FlushInterval: flushInterval, + MaxRetries: maxRetries, + RetryDelay: retryDelay, + MaxRetryDelay: maxRetryDelay, + ConnTimeout: connTimeout, + }, nil +} + +// getIntFromConfig safely extracts an integer value from config map with default fallback +func getIntFromConfig(config map[string]interface{}, key string, defaultValue int) int { + if val, found := config[key]; found { + if intVal, ok := val.(int); ok { + return intVal + } + } + return defaultValue +} + +// getDurationFromConfig safely extracts a duration value from config map with default fallback +func getDurationFromConfig(config map[string]interface{}, key string, defaultValue time.Duration) time.Duration { + if val, found := config[key]; found { + if strVal, ok := val.(string); ok { + if duration, err := time.ParseDuration(strVal); err == nil { + return duration + } + } + } + return defaultValue +} diff --git a/pkg/syncer/pubsub/redis_streams.go b/pkg/syncer/pubsub/redis_streams.go new file mode 100644 index 00000000..c7281aa0 --- /dev/null +++ b/pkg/syncer/pubsub/redis_streams.go @@ -0,0 +1,504 @@ +/**************************************************************************** + * Copyright 2025 Optimizely, Inc. and contributors * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); * + * you may not use this file except in compliance with the License. 
* + * You may obtain a copy of the License at * + * * + * http://www.apache.org/licenses/LICENSE-2.0 * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + ***************************************************************************/ + +// Package pubsub provides pubsub functionality for the agent syncer +package pubsub + +import ( + "context" + "encoding/json" + "fmt" + "math" + "strings" + "time" + + "github.com/go-redis/redis/v8" + "github.com/rs/zerolog/log" + + "github.com/optimizely/agent/pkg/metrics" +) + +// RedisStreams implements persistent message delivery using Redis Streams +type RedisStreams struct { + Host string + Password string + Database int + // Stream configuration + MaxLen int64 + ConsumerGroup string + ConsumerName string + // Batching configuration + BatchSize int + FlushInterval time.Duration + // Retry configuration + MaxRetries int + RetryDelay time.Duration + MaxRetryDelay time.Duration + // Connection timeout + ConnTimeout time.Duration + // Metrics registry + metricsRegistry *metrics.Registry +} + +func (r *RedisStreams) Publish(ctx context.Context, channel string, message interface{}) error { + streamName := r.getStreamName(channel) + + // Convert message to string for consistent handling + var messageStr string + switch v := message.(type) { + case []byte: + messageStr = string(v) + case string: + messageStr = v + default: + // For other types, marshal to JSON + jsonBytes, err := json.Marshal(v) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + messageStr = string(jsonBytes) + } + + // Add message to stream with automatic ID generation + args := &redis.XAddArgs{ + Stream: streamName, + Values: map[string]interface{}{ + "data": messageStr, + "timestamp": time.Now().Unix(), + }, + } + + // Apply max length trimming if configured + if r.MaxLen > 0 { + args.MaxLen = r.MaxLen + args.Approx = true // Use approximate trimming for better performance + } + + return r.executeWithRetry(ctx, func(client *redis.Client) error { + return client.XAdd(ctx, args).Err() + }) +} + +func (r *RedisStreams) Subscribe(ctx context.Context, channel string) (chan string, error) { + streamName := r.getStreamName(channel) + consumerGroup := r.getConsumerGroup() + consumerName := r.getConsumerName() + + ch := make(chan string) + + go func() { + defer close(ch) + + batchSize := r.getBatchSize() + flushTicker := time.NewTicker(r.getFlushInterval()) + defer flushTicker.Stop() + + var batch []string + var client *redis.Client + var lastReconnect time.Time + reconnectDelay := 1 * time.Second + maxReconnectDelay := 30 * time.Second + + // Initialize connection + client = r.createClient() + defer client.Close() + + // Create consumer group with retry + if err := r.createConsumerGroupWithRetry(ctx, client, streamName, consumerGroup); err != nil { + log.Error().Err(err).Str("stream", streamName).Str("group", consumerGroup).Msg("Failed to create consumer group") + return + } + + for { + select { + case <-ctx.Done(): + // Send any remaining batch before closing + if len(batch) > 0 { + r.sendBatch(ch, batch, ctx) + } + return + case <-flushTicker.C: + // Flush interval reached - send current batch + if len(batch) > 0 { + r.incrementCounter("batch.flush_interval") + r.sendBatch(ch, 
batch, ctx) + batch = nil + } + default: + // Read messages from the stream using consumer group + streams, err := client.XReadGroup(ctx, &redis.XReadGroupArgs{ + Group: consumerGroup, + Consumer: consumerName, + Streams: []string{streamName, ">"}, + Count: int64(batchSize - len(batch)), // Read up to remaining batch size + Block: 100 * time.Millisecond, // Short block to allow flush checking + }).Result() + + if err != nil { + if err == redis.Nil { + continue // No messages, continue polling + } + + // Handle connection errors with exponential backoff reconnection + if r.isConnectionError(err) { + r.incrementCounter("connection.error") + log.Warn().Err(err).Msg("Redis connection error, attempting reconnection") + + // Apply exponential backoff for reconnection + if time.Since(lastReconnect) > reconnectDelay { + r.incrementCounter("connection.reconnect_attempt") + client.Close() + client = r.createClient() + lastReconnect = time.Now() + + // Recreate consumer group after reconnection + if groupErr := r.createConsumerGroupWithRetry(ctx, client, streamName, consumerGroup); groupErr != nil { + r.incrementCounter("connection.group_recreate_error") + log.Error().Err(groupErr).Msg("Failed to recreate consumer group after reconnection") + } else { + r.incrementCounter("connection.reconnect_success") + } + + // Increase reconnect delay with exponential backoff + reconnectDelay = time.Duration(math.Min(float64(reconnectDelay*2), float64(maxReconnectDelay))) + } else { + // Wait before next retry + time.Sleep(100 * time.Millisecond) + } + } else { + // Log other errors but continue processing + r.incrementCounter("read.error") + log.Debug().Err(err).Msg("Redis streams read error") + } + continue + } + + // Reset reconnect delay on successful read + reconnectDelay = 1 * time.Second + + // Process messages from streams + messageCount := 0 + for _, stream := range streams { + for _, message := range stream.Messages { + // Extract the data field + if data, ok := message.Values["data"].(string); ok { + batch = append(batch, data) + messageCount++ + + // Acknowledge the message with retry + if ackErr := r.acknowledgeMessage(ctx, client, streamName, consumerGroup, message.ID); ackErr != nil { + log.Warn().Err(ackErr).Str("messageID", message.ID).Msg("Failed to acknowledge message") + } + + // Send batch if it's full + if len(batch) >= batchSize { + r.incrementCounter("batch.sent") + r.sendBatch(ch, batch, ctx) + batch = nil + // Continue processing more messages + } + } + } + } + + // Track successful message reads + if messageCount > 0 { + r.incrementCounter("messages.read") + } + } + } + }() + + return ch, nil +} + +// Helper method to send batch to channel +func (r *RedisStreams) sendBatch(ch chan string, batch []string, ctx context.Context) { + for _, msg := range batch { + select { + case ch <- msg: + // Message sent successfully + case <-ctx.Done(): + return + } + } +} + +// Helper methods +func (r *RedisStreams) getStreamName(channel string) string { + return fmt.Sprintf("stream:%s", channel) +} + +func (r *RedisStreams) getConsumerGroup() string { + if r.ConsumerGroup == "" { + return "notifications" + } + return r.ConsumerGroup +} + +func (r *RedisStreams) getConsumerName() string { + if r.ConsumerName == "" { + return fmt.Sprintf("consumer-%d", time.Now().UnixNano()) + } + return r.ConsumerName +} + +func (r *RedisStreams) getBatchSize() int { + if r.BatchSize <= 0 { + return 10 // Default batch size + } + return r.BatchSize +} + +func (r *RedisStreams) getFlushInterval() time.Duration { + if 
r.FlushInterval <= 0 { + return 5 * time.Second // Default flush interval + } + return r.FlushInterval +} + +func (r *RedisStreams) getMaxRetries() int { + if r.MaxRetries <= 0 { + return 3 // Default max retries + } + return r.MaxRetries +} + +func (r *RedisStreams) getRetryDelay() time.Duration { + if r.RetryDelay <= 0 { + return 100 * time.Millisecond // Default retry delay + } + return r.RetryDelay +} + +func (r *RedisStreams) getMaxRetryDelay() time.Duration { + if r.MaxRetryDelay <= 0 { + return 5 * time.Second // Default max retry delay + } + return r.MaxRetryDelay +} + +func (r *RedisStreams) getConnTimeout() time.Duration { + if r.ConnTimeout <= 0 { + return 10 * time.Second // Default connection timeout + } + return r.ConnTimeout +} + +// createClient creates a new Redis client with configured timeouts +func (r *RedisStreams) createClient() *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: r.Host, + Password: r.Password, + DB: r.Database, + DialTimeout: r.getConnTimeout(), + ReadTimeout: r.getConnTimeout(), + WriteTimeout: r.getConnTimeout(), + PoolTimeout: r.getConnTimeout(), + }) +} + +// executeWithRetry executes a Redis operation with retry logic +func (r *RedisStreams) executeWithRetry(ctx context.Context, operation func(client *redis.Client) error) error { + start := time.Now() + maxRetries := r.getMaxRetries() + retryDelay := r.getRetryDelay() + maxRetryDelay := r.getMaxRetryDelay() + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + client := r.createClient() + err := operation(client) + client.Close() + + if err == nil { + // Record successful operation metrics + r.incrementCounter("operations.success") + r.recordTimer("operations.duration", time.Since(start).Seconds()) + if attempt > 0 { + r.incrementCounter("retries.success") + } + return nil // Success + } + + lastErr = err + r.incrementCounter("operations.error") + + // Don't retry on non-recoverable errors + if !r.isRetryableError(err) { + r.incrementCounter("errors.non_retryable") + return fmt.Errorf("non-retryable error: %w", err) + } + + // Don't sleep after the last attempt + if attempt < maxRetries { + r.incrementCounter("retries.attempt") + // Calculate delay with exponential backoff + delay := time.Duration(math.Min(float64(retryDelay)*math.Pow(2, float64(attempt)), float64(maxRetryDelay))) + + select { + case <-ctx.Done(): + r.incrementCounter("operations.canceled") + return ctx.Err() + case <-time.After(delay): + // Continue to next retry + } + } + } + + r.incrementCounter("retries.exhausted") + return fmt.Errorf("operation failed after %d retries: %w", maxRetries, lastErr) +} + +// createConsumerGroupWithRetry creates a consumer group with retry logic +func (r *RedisStreams) createConsumerGroupWithRetry(ctx context.Context, _ *redis.Client, streamName, consumerGroup string) error { + return r.executeWithRetry(ctx, func(retryClient *redis.Client) error { + _, err := retryClient.XGroupCreateMkStream(ctx, streamName, consumerGroup, "$").Result() + if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" { + return fmt.Errorf("failed to create consumer group: %w", err) + } + return nil + }) +} + +// acknowledgeMessage acknowledges a message with retry logic +func (r *RedisStreams) acknowledgeMessage(ctx context.Context, client *redis.Client, streamName, consumerGroup, messageID string) error { + maxRetries := 2 // Fewer retries for ACK operations + retryDelay := 50 * time.Millisecond + + var lastErr error + for attempt := 0; attempt <= maxRetries; 
attempt++ { + err := client.XAck(ctx, streamName, consumerGroup, messageID).Err() + if err == nil { + r.incrementCounter("ack.success") + if attempt > 0 { + r.incrementCounter("ack.retry_success") + } + return nil // Success + } + + lastErr = err + r.incrementCounter("ack.error") + + // Don't retry on non-recoverable errors + if !r.isRetryableError(err) { + r.incrementCounter("ack.non_retryable_error") + return fmt.Errorf("non-retryable ACK error: %w", err) + } + + // Don't sleep after the last attempt + if attempt < maxRetries { + r.incrementCounter("ack.retry_attempt") + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(retryDelay): + // Continue to next retry + } + } + } + + r.incrementCounter("ack.retry_exhausted") + return fmt.Errorf("ACK failed after %d retries: %w", maxRetries, lastErr) +} + +// isRetryableError determines if an error is retryable +func (r *RedisStreams) isRetryableError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + + // Network/connection errors that are retryable + retryableErrors := []string{ + "connection refused", + "connection reset", + "timeout", + "network is unreachable", + "broken pipe", + "eof", + "i/o timeout", + "connection pool exhausted", + "context deadline exceeded", + "context canceled", // Handle graceful shutdowns + "no such host", // DNS lookup failures + } + + for _, retryable := range retryableErrors { + if strings.Contains(strings.ToLower(errStr), retryable) { + return true + } + } + + // Redis-specific retryable errors + if strings.Contains(errStr, "LOADING") || // Redis is loading data + strings.Contains(errStr, "READONLY") || // Redis is in read-only mode + strings.Contains(errStr, "CLUSTERDOWN") { // Redis cluster is down + return true + } + + return false +} + +// isConnectionError determines if an error is a connection error +func (r *RedisStreams) isConnectionError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + + connectionErrors := []string{ + "connection refused", + "connection reset", + "network is unreachable", + "broken pipe", + "eof", + "connection pool exhausted", + } + + for _, connErr := range connectionErrors { + if strings.Contains(strings.ToLower(errStr), connErr) { + return true + } + } + + return false +} + +// SetMetricsRegistry sets the metrics registry for tracking statistics +func (r *RedisStreams) SetMetricsRegistry(registry *metrics.Registry) { + r.metricsRegistry = registry +} + +// incrementCounter safely increments a metrics counter if registry is available +func (r *RedisStreams) incrementCounter(key string) { + if r.metricsRegistry != nil { + if counter := r.metricsRegistry.GetCounter("redis_streams." + key); counter != nil { + counter.Add(1) + } + } +} + +// recordTimer safely records a timer metric if registry is available +func (r *RedisStreams) recordTimer(key string, duration float64) { + if r.metricsRegistry != nil { + if timer := r.metricsRegistry.NewTimer("redis_streams." + key); timer != nil { + timer.Update(duration) + } + } +} diff --git a/pkg/syncer/pubsub/redis_streams_error_test.go b/pkg/syncer/pubsub/redis_streams_error_test.go new file mode 100644 index 00000000..6b48b509 --- /dev/null +++ b/pkg/syncer/pubsub/redis_streams_error_test.go @@ -0,0 +1,475 @@ +/**************************************************************************** + * Copyright 2025 Optimizely, Inc. 
and contributors * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); * + * you may not use this file except in compliance with the License. * + * You may obtain a copy of the License at * + * * + * http://www.apache.org/licenses/LICENSE-2.0 * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + ***************************************************************************/ + +package pubsub + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/optimizely/agent/pkg/metrics" +) + +func setupRedisStreamsWithRetry() *RedisStreams { + return &RedisStreams{ + Host: "localhost:6379", + Password: "", + Database: 0, + MaxLen: 1000, + ConsumerGroup: "test-group", + ConsumerName: "test-consumer", + BatchSize: 10, + FlushInterval: 5 * time.Second, + MaxRetries: 3, + RetryDelay: 50 * time.Millisecond, + MaxRetryDelay: 1 * time.Second, + ConnTimeout: 5 * time.Second, + // Don't set metricsRegistry by default to avoid conflicts + metricsRegistry: nil, + } +} + +func TestRedisStreams_RetryConfiguration_Defaults(t *testing.T) { + rs := &RedisStreams{} + + assert.Equal(t, 3, rs.getMaxRetries()) + assert.Equal(t, 100*time.Millisecond, rs.getRetryDelay()) + assert.Equal(t, 5*time.Second, rs.getMaxRetryDelay()) + assert.Equal(t, 10*time.Second, rs.getConnTimeout()) +} + +func TestRedisStreams_RetryConfiguration_Custom(t *testing.T) { + rs := &RedisStreams{ + MaxRetries: 5, + RetryDelay: 200 * time.Millisecond, + MaxRetryDelay: 10 * time.Second, + ConnTimeout: 30 * time.Second, + } + + assert.Equal(t, 5, rs.getMaxRetries()) + assert.Equal(t, 200*time.Millisecond, rs.getRetryDelay()) + assert.Equal(t, 10*time.Second, rs.getMaxRetryDelay()) + assert.Equal(t, 30*time.Second, rs.getConnTimeout()) +} + +func TestRedisStreams_IsRetryableError(t *testing.T) { + rs := setupRedisStreamsWithRetry() + + testCases := []struct { + name string + err error + retryable bool + }{ + { + name: "nil error", + err: nil, + retryable: false, + }, + { + name: "connection refused", + err: errors.New("connection refused"), + retryable: true, + }, + { + name: "connection reset", + err: errors.New("connection reset by peer"), + retryable: true, + }, + { + name: "timeout error", + err: errors.New("i/o timeout"), + retryable: true, + }, + { + name: "network unreachable", + err: errors.New("network is unreachable"), + retryable: true, + }, + { + name: "broken pipe", + err: errors.New("broken pipe"), + retryable: true, + }, + { + name: "EOF error", + err: errors.New("EOF"), + retryable: true, + }, + { + name: "context deadline exceeded", + err: errors.New("context deadline exceeded"), + retryable: true, + }, + { + name: "context canceled", + err: errors.New("context canceled"), + retryable: true, + }, + { + name: "Redis LOADING", + err: errors.New("LOADING Redis is loading the dataset in memory"), + retryable: true, + }, + { + name: "Redis READONLY", + err: errors.New("READONLY You can't write against a read only replica."), + retryable: true, + }, + { + name: "Redis CLUSTERDOWN", + err: errors.New("CLUSTERDOWN Hash slot not served"), + retryable: true, + }, + { + name: "syntax error - not 
retryable", + err: errors.New("ERR syntax error"), + retryable: false, + }, + { + name: "wrong type error - not retryable", + err: errors.New("WRONGTYPE Operation against a key holding the wrong kind of value"), + retryable: false, + }, + { + name: "authentication error - not retryable", + err: errors.New("NOAUTH Authentication required"), + retryable: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := rs.isRetryableError(tc.err) + assert.Equal(t, tc.retryable, result, "Error: %v", tc.err) + }) + } +} + +func TestRedisStreams_IsConnectionError(t *testing.T) { + rs := setupRedisStreamsWithRetry() + + testCases := []struct { + name string + err error + isConnection bool + }{ + { + name: "nil error", + err: nil, + isConnection: false, + }, + { + name: "connection refused", + err: errors.New("connection refused"), + isConnection: true, + }, + { + name: "connection reset", + err: errors.New("connection reset by peer"), + isConnection: true, + }, + { + name: "network unreachable", + err: errors.New("network is unreachable"), + isConnection: true, + }, + { + name: "broken pipe", + err: errors.New("broken pipe"), + isConnection: true, + }, + { + name: "EOF error", + err: errors.New("EOF"), + isConnection: true, + }, + { + name: "syntax error - not connection", + err: errors.New("ERR syntax error"), + isConnection: false, + }, + { + name: "timeout - not connection", + err: errors.New("i/o timeout"), + isConnection: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := rs.isConnectionError(tc.err) + assert.Equal(t, tc.isConnection, result, "Error: %v", tc.err) + }) + } +} + +func TestRedisStreams_Publish_WithInvalidHost_ShouldRetry(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.Host = "invalid-host:6379" // Use invalid host to trigger connection errors + rs.MaxRetries = 2 // Limit retries for faster test + rs.RetryDelay = 10 * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := rs.Publish(ctx, "test-channel", "test message") + + // Should fail after retries + assert.Error(t, err) + assert.Contains(t, err.Error(), "operation failed after 2 retries") +} + +func TestRedisStreams_Publish_WithCanceledContext(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.Host = "invalid-host:6379" // Use invalid host to trigger retries + rs.MaxRetries = 5 + rs.RetryDelay = 100 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel context immediately to test cancellation handling + cancel() + + err := rs.Publish(ctx, "test-channel", "test message") + + // Should fail with context canceled error + assert.Error(t, err) + // Could be either context canceled directly or wrapped in retry error + assert.True(t, strings.Contains(err.Error(), "context canceled") || + strings.Contains(err.Error(), "operation failed after")) +} + +func TestRedisStreams_MetricsIntegration(t *testing.T) { + rs := setupRedisStreamsWithRetry() + + // Test that metrics registry can be set and retrieved + registry := metrics.NewRegistry("metrics_integration_test") + rs.SetMetricsRegistry(registry) + + assert.Equal(t, registry, rs.metricsRegistry) +} + +func TestRedisStreams_MetricsTracking_SafeWithNilRegistry(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.metricsRegistry = nil + + // These should not panic with nil registry + rs.incrementCounter("test.counter") + rs.recordTimer("test.timer", 1.5) +} + +func 
TestRedisStreams_CreateClient_WithTimeouts(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.ConnTimeout = 2 * time.Second + + client := rs.createClient() + defer client.Close() + + assert.NotNil(t, client) + // Note: go-redis client options are not easily inspectable, + // but we can verify the client was created without error +} + +func TestRedisStreams_AcknowledgeMessage_WithRetry(t *testing.T) { + // This test requires a running Redis instance + rs := setupRedisStreamsWithRetry() + ctx := context.Background() + + // Create a client to set up test data + client := redis.NewClient(&redis.Options{ + Addr: rs.Host, + Password: rs.Password, + DB: rs.Database, + }) + defer client.Close() + + streamName := "test-ack-stream" + consumerGroup := "test-ack-group" + + // Clean up + defer func() { + client.Del(ctx, streamName) + }() + + // Add a message to the stream + msgID, err := client.XAdd(ctx, &redis.XAddArgs{ + Stream: streamName, + Values: map[string]interface{}{ + "data": "test message", + }, + }).Result() + require.NoError(t, err) + + // Create consumer group + client.XGroupCreateMkStream(ctx, streamName, consumerGroup, "0") + + // Test acknowledge with valid message ID (should succeed) + err = rs.acknowledgeMessage(ctx, client, streamName, consumerGroup, msgID) + assert.NoError(t, err) + + // Test acknowledge with invalid message ID (should fail but not crash) + err = rs.acknowledgeMessage(ctx, client, streamName, consumerGroup, "invalid-id") + assert.Error(t, err) +} + +func TestRedisStreams_ExecuteWithRetry_NonRetryableError(t *testing.T) { + rs := setupRedisStreamsWithRetry() + ctx := context.Background() + + // Simulate a non-retryable error + operation := func(client *redis.Client) error { + return errors.New("ERR syntax error") // Non-retryable + } + + err := rs.executeWithRetry(ctx, operation) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-retryable error") + assert.Contains(t, err.Error(), "ERR syntax error") +} + +func TestRedisStreams_ExecuteWithRetry_SuccessAfterRetries(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.RetryDelay = 1 * time.Millisecond // Fast retries for testing + // Set unique registry to avoid conflicts + rs.SetMetricsRegistry(metrics.NewRegistry("success_after_retries_test_" + time.Now().Format("20060102150405"))) + ctx := context.Background() + + attemptCount := 0 + operation := func(client *redis.Client) error { + attemptCount++ + if attemptCount < 3 { + return errors.New("connection refused") // Retryable + } + return nil // Success on third attempt + } + + err := rs.executeWithRetry(ctx, operation) + + assert.NoError(t, err) + assert.Equal(t, 3, attemptCount) +} + +func TestRedisStreams_ExecuteWithRetry_ExhaustRetries(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.MaxRetries = 2 + rs.RetryDelay = 1 * time.Millisecond // Fast retries for testing + // Set unique registry to avoid conflicts + rs.SetMetricsRegistry(metrics.NewRegistry("exhaust_retries_test_" + time.Now().Format("20060102150405"))) + ctx := context.Background() + + attemptCount := 0 + operation := func(client *redis.Client) error { + attemptCount++ + return errors.New("connection refused") // Always retryable error + } + + err := rs.executeWithRetry(ctx, operation) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "operation failed after 2 retries") + assert.Equal(t, 3, attemptCount) // 1 initial + 2 retries +} + +func TestRedisStreams_CreateConsumerGroupWithRetry_BusyGroupExists(t *testing.T) { + rs := setupRedisStreamsWithRetry() + ctx 
:= context.Background() + + // Create a client to set up test data + client := redis.NewClient(&redis.Options{ + Addr: rs.Host, + Password: rs.Password, + DB: rs.Database, + }) + defer client.Close() + + streamName := "test-busy-group-stream" + consumerGroup := "test-busy-group" + + // Clean up + defer func() { + client.Del(ctx, streamName) + }() + + // First call should succeed + err := rs.createConsumerGroupWithRetry(ctx, client, streamName, consumerGroup) + assert.NoError(t, err) + + // Second call should also succeed (BUSYGROUP error is handled) + err = rs.createConsumerGroupWithRetry(ctx, client, streamName, consumerGroup) + assert.NoError(t, err) +} + +func TestRedisStreams_ErrorHandling_ContextCancellation(t *testing.T) { + rs := setupRedisStreamsWithRetry() + rs.RetryDelay = 100 * time.Millisecond + // Set unique registry to avoid conflicts + rs.SetMetricsRegistry(metrics.NewRegistry("context_cancellation_test_" + time.Now().Format("20060102150405"))) + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + // Cancel context after a short delay + time.Sleep(50 * time.Millisecond) + cancel() + }() + + operation := func(client *redis.Client) error { + return errors.New("connection refused") // Retryable error + } + + err := rs.executeWithRetry(ctx, operation) + + assert.Error(t, err) + assert.Equal(t, context.Canceled, err) +} + +func TestRedisStreams_Subscribe_ErrorRecovery_Integration(t *testing.T) { + // Integration test - requires Redis to be running + rs := setupRedisStreamsWithRetry() + rs.MaxRetries = 1 // Limit retries for faster test + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + channel := "test-error-recovery" + defer cleanupRedisStream(rs.getStreamName(channel)) + + // Start subscriber + ch, err := rs.Subscribe(ctx, channel) + require.NoError(t, err) + + // Give some time for setup + time.Sleep(100 * time.Millisecond) + + // Publish a message + err = rs.Publish(ctx, channel, "test message") + require.NoError(t, err) + + // Should receive the message despite any internal error recovery + select { + case received := <-ch: + assert.Equal(t, "test message", received) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message") + } +} diff --git a/pkg/syncer/pubsub/redis_streams_test.go b/pkg/syncer/pubsub/redis_streams_test.go new file mode 100644 index 00000000..32a348f6 --- /dev/null +++ b/pkg/syncer/pubsub/redis_streams_test.go @@ -0,0 +1,342 @@ +/**************************************************************************** + * Copyright 2025 Optimizely, Inc. and contributors * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); * + * you may not use this file except in compliance with the License. * + * You may obtain a copy of the License at * + * * + * http://www.apache.org/licenses/LICENSE-2.0 * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. 
* + ***************************************************************************/ + +package pubsub + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testRedisHost = "localhost:6379" + testDatabase = 0 + testPassword = "" +) + +func setupRedisStreams() *RedisStreams { + return &RedisStreams{ + Host: testRedisHost, + Password: testPassword, + Database: testDatabase, + MaxLen: 1000, + ConsumerGroup: "test-group", + ConsumerName: "test-consumer", + BatchSize: 10, + FlushInterval: 5 * time.Second, + } +} + +func cleanupRedisStream(streamName string) { + client := redis.NewClient(&redis.Options{ + Addr: testRedisHost, + Password: testPassword, + DB: testDatabase, + }) + defer client.Close() + + // Delete the stream and consumer group + client.Del(context.Background(), streamName) +} + +func TestRedisStreams_Publish_String(t *testing.T) { + rs := setupRedisStreams() + ctx := context.Background() + channel := "test-channel-string" + message := "test message" + + defer cleanupRedisStream(rs.getStreamName(channel)) + + err := rs.Publish(ctx, channel, message) + assert.NoError(t, err) + + // Verify message was added to stream + client := redis.NewClient(&redis.Options{ + Addr: testRedisHost, + Password: testPassword, + DB: testDatabase, + }) + defer client.Close() + + streamName := rs.getStreamName(channel) + messages, err := client.XRange(ctx, streamName, "-", "+").Result() + require.NoError(t, err) + assert.Len(t, messages, 1) + + // Check message content + data, exists := messages[0].Values["data"] + assert.True(t, exists) + assert.Equal(t, message, data) + + // Check timestamp exists + timestamp, exists := messages[0].Values["timestamp"] + assert.True(t, exists) + assert.NotNil(t, timestamp) +} + +func TestRedisStreams_Publish_JSON(t *testing.T) { + rs := setupRedisStreams() + ctx := context.Background() + channel := "test-channel-json" + + testObj := map[string]interface{}{ + "type": "notification", + "payload": "test data", + "id": 123, + } + + defer cleanupRedisStream(rs.getStreamName(channel)) + + err := rs.Publish(ctx, channel, testObj) + assert.NoError(t, err) + + // Verify message was serialized correctly + client := redis.NewClient(&redis.Options{ + Addr: testRedisHost, + Password: testPassword, + DB: testDatabase, + }) + defer client.Close() + + streamName := rs.getStreamName(channel) + messages, err := client.XRange(ctx, streamName, "-", "+").Result() + require.NoError(t, err) + assert.Len(t, messages, 1) + + // Check JSON was stored correctly + data, exists := messages[0].Values["data"] + assert.True(t, exists) + + var decoded map[string]interface{} + err = json.Unmarshal([]byte(data.(string)), &decoded) + require.NoError(t, err) + assert.Equal(t, testObj["type"], decoded["type"]) + assert.Equal(t, testObj["payload"], decoded["payload"]) + assert.Equal(t, float64(123), decoded["id"]) // JSON numbers become float64 +} + +func TestRedisStreams_Publish_ByteArray(t *testing.T) { + rs := setupRedisStreams() + ctx := context.Background() + channel := "test-channel-bytes" + message := []byte("test byte message") + + defer cleanupRedisStream(rs.getStreamName(channel)) + + err := rs.Publish(ctx, channel, message) + assert.NoError(t, err) + + // Verify message was stored as string + client := redis.NewClient(&redis.Options{ + Addr: testRedisHost, + Password: testPassword, + DB: testDatabase, + }) + defer client.Close() + + streamName := 
rs.getStreamName(channel) + messages, err := client.XRange(ctx, streamName, "-", "+").Result() + require.NoError(t, err) + assert.Len(t, messages, 1) + + data, exists := messages[0].Values["data"] + assert.True(t, exists) + assert.Equal(t, string(message), data) +} + +func TestRedisStreams_Subscribe_BasicFlow(t *testing.T) { + rs := setupRedisStreams() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + channel := "test-channel-subscribe" + defer cleanupRedisStream(rs.getStreamName(channel)) + + // Start subscriber + ch, err := rs.Subscribe(ctx, channel) + require.NoError(t, err) + + // Give subscriber time to set up + time.Sleep(100 * time.Millisecond) + + // Publish a message AFTER subscriber is ready + testMessage := "subscription test message" + err = rs.Publish(ctx, channel, testMessage) + require.NoError(t, err) + + // Wait for message + select { + case received := <-ch: + assert.Equal(t, testMessage, received) + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for message") + } +} + +func TestRedisStreams_Subscribe_MultipleMessages(t *testing.T) { + rs := setupRedisStreams() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + channel := "test-channel-multiple" + defer cleanupRedisStream(rs.getStreamName(channel)) + + // Start subscriber + ch, err := rs.Subscribe(ctx, channel) + require.NoError(t, err) + + // Give subscriber time to set up + time.Sleep(100 * time.Millisecond) + + // Publish multiple messages AFTER subscriber is ready + messages := []string{"message1", "message2", "message3"} + for _, msg := range messages { + err = rs.Publish(ctx, channel, msg) + require.NoError(t, err) + } + + // Collect received messages + var received []string + timeout := time.After(5 * time.Second) + + for i := 0; i < len(messages); i++ { + select { + case msg := <-ch: + received = append(received, msg) + case <-timeout: + t.Fatalf("Timeout waiting for message %d", i+1) + } + } + + assert.ElementsMatch(t, messages, received) +} + +func TestRedisStreams_HelperMethods(t *testing.T) { + rs := setupRedisStreams() + + // Test getStreamName + channel := "test-channel" + expected := "stream:test-channel" + assert.Equal(t, expected, rs.getStreamName(channel)) + + // Test getConsumerGroup + assert.Equal(t, "test-group", rs.getConsumerGroup()) + + // Test getConsumerGroup with empty value + rs.ConsumerGroup = "" + assert.Equal(t, "notifications", rs.getConsumerGroup()) + + // Test getConsumerName + rs.ConsumerName = "custom-consumer" + assert.Equal(t, "custom-consumer", rs.getConsumerName()) + + // Test getConsumerName with empty value (should generate unique name) + rs.ConsumerName = "" + name1 := rs.getConsumerName() + assert.Contains(t, name1, "consumer-") + // Note: getConsumerName generates the same name unless we create a new instance + + // Test getBatchSize + assert.Equal(t, 10, rs.getBatchSize()) + rs.BatchSize = 0 + assert.Equal(t, 10, rs.getBatchSize()) // Default + rs.BatchSize = -5 + assert.Equal(t, 10, rs.getBatchSize()) // Default for negative + + // Test getFlushInterval + rs.FlushInterval = 3 * time.Second + assert.Equal(t, 3*time.Second, rs.getFlushInterval()) + rs.FlushInterval = 0 + assert.Equal(t, 5*time.Second, rs.getFlushInterval()) // Default + rs.FlushInterval = -1 * time.Second + assert.Equal(t, 5*time.Second, rs.getFlushInterval()) // Default for negative +} + +func TestRedisStreams_Batching_Behavior(t *testing.T) { + rs := setupRedisStreams() + rs.BatchSize = 3 // Set small 
batch size for testing + rs.FlushInterval = 10 * time.Second // Long interval to test batch size trigger + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + channel := "test-channel-batching" + defer cleanupRedisStream(rs.getStreamName(channel)) + + // Start subscriber + ch, err := rs.Subscribe(ctx, channel) + require.NoError(t, err) + + // Publish messages to trigger batch + messages := []string{"batch1", "batch2", "batch3"} + for _, msg := range messages { + err = rs.Publish(ctx, channel, msg) + require.NoError(t, err) + } + + // Should receive all messages in one batch + var received []string + timeout := time.After(3 * time.Second) + + for len(received) < len(messages) { + select { + case msg := <-ch: + received = append(received, msg) + case <-timeout: + t.Fatalf("Timeout waiting for batched messages. Received %d out of %d", len(received), len(messages)) + } + } + + assert.ElementsMatch(t, messages, received) +} + +func TestRedisStreams_MaxLen_Configuration(t *testing.T) { + rs := setupRedisStreams() + rs.MaxLen = 2 // Very small max length + + ctx := context.Background() + channel := "test-channel-maxlen" + defer cleanupRedisStream(rs.getStreamName(channel)) + + // Publish more messages than MaxLen + messages := []string{"msg1", "msg2", "msg3", "msg4"} + for _, msg := range messages { + err := rs.Publish(ctx, channel, msg) + require.NoError(t, err) + } + + // Verify stream was trimmed + client := redis.NewClient(&redis.Options{ + Addr: testRedisHost, + Password: testPassword, + DB: testDatabase, + }) + defer client.Close() + + streamName := rs.getStreamName(channel) + length, err := client.XLen(ctx, streamName).Result() + require.NoError(t, err) + + // Should be approximately MaxLen (Redis uses approximate trimming) + // With APPROX, Redis may keep more entries than specified + assert.LessOrEqual(t, length, int64(10)) // Allow generous buffer for approximate trimming +} diff --git a/pkg/syncer/pubsub_test.go b/pkg/syncer/pubsub_test.go index 31b3dc1d..bcabab01 100644 --- a/pkg/syncer/pubsub_test.go +++ b/pkg/syncer/pubsub_test.go @@ -20,6 +20,7 @@ package syncer import ( "reflect" "testing" + "time" "github.com/optimizely/agent/config" "github.com/optimizely/agent/pkg/syncer/pubsub" @@ -260,6 +261,116 @@ func TestNewPubSub(t *testing.T) { want: nil, wantErr: true, }, + { + name: "Test with valid redis-streams config for notification", + args: args{ + conf: config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + "batch_size": 20, + "flush_interval": "10s", + "max_retries": 5, + "retry_delay": "200ms", + "max_retry_delay": "10s", + "connection_timeout": "15s", + }, + }, + Notification: config.FeatureSyncConfig{ + Default: "redis-streams", + Enable: true, + }, + }, + flag: SyncFeatureFlagNotificaiton, + }, + want: &pubsub.RedisStreams{ + Host: "localhost:6379", + Password: "", + Database: 0, + BatchSize: 20, + FlushInterval: 10000000000, // 10s in nanoseconds + MaxRetries: 5, + RetryDelay: 200000000, // 200ms in nanoseconds + MaxRetryDelay: 10000000000, // 10s in nanoseconds + ConnTimeout: 15000000000, // 15s in nanoseconds + }, + wantErr: false, + }, + { + name: "Test with valid redis-streams config for datafile", + args: args{ + conf: config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + }, + }, + Datafile: config.FeatureSyncConfig{ + Default: 
"redis-streams", + Enable: true, + }, + }, + flag: SycnFeatureFlagDatafile, + }, + want: &pubsub.RedisStreams{ + Host: "localhost:6379", + Password: "", + Database: 0, + BatchSize: 10, // default + FlushInterval: 5000000000, // 5s default in nanoseconds + MaxRetries: 3, // default + RetryDelay: 100000000, // 100ms default in nanoseconds + MaxRetryDelay: 5000000000, // 5s default in nanoseconds + ConnTimeout: 10000000000, // 10s default in nanoseconds + }, + wantErr: false, + }, + { + name: "Test with unsupported pubsub type", + args: args{ + conf: config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + }, + }, + Notification: config.FeatureSyncConfig{ + Default: "unsupported-type", + Enable: true, + }, + }, + flag: SyncFeatureFlagNotificaiton, + }, + want: nil, + wantErr: true, + }, + { + name: "Test with invalid feature flag", + args: args{ + conf: config.SyncConfig{ + Pubsub: map[string]interface{}{ + "redis": map[string]interface{}{ + "host": "localhost:6379", + "password": "", + "database": 0, + }, + }, + Notification: config.FeatureSyncConfig{ + Default: "redis", + Enable: true, + }, + }, + flag: "invalid-flag", + }, + want: nil, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -274,3 +385,133 @@ func TestNewPubSub(t *testing.T) { }) } } + +func TestGetIntFromConfig(t *testing.T) { + tests := []struct { + name string + config map[string]interface{} + key string + defaultValue int + want int + }{ + { + name: "Valid int value", + config: map[string]interface{}{ + "test_key": 42, + }, + key: "test_key", + defaultValue: 10, + want: 42, + }, + { + name: "Missing key returns default", + config: map[string]interface{}{ + "other_key": 42, + }, + key: "test_key", + defaultValue: 10, + want: 10, + }, + { + name: "Invalid type returns default", + config: map[string]interface{}{ + "test_key": "not an int", + }, + key: "test_key", + defaultValue: 10, + want: 10, + }, + { + name: "Nil value returns default", + config: map[string]interface{}{ + "test_key": nil, + }, + key: "test_key", + defaultValue: 10, + want: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getIntFromConfig(tt.config, tt.key, tt.defaultValue) + if got != tt.want { + t.Errorf("getIntFromConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetDurationFromConfig(t *testing.T) { + tests := []struct { + name string + config map[string]interface{} + key string + defaultValue time.Duration + want time.Duration + }{ + { + name: "Valid duration string", + config: map[string]interface{}{ + "test_key": "5s", + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 5 * time.Second, + }, + { + name: "Valid millisecond duration", + config: map[string]interface{}{ + "test_key": "100ms", + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 100 * time.Millisecond, + }, + { + name: "Missing key returns default", + config: map[string]interface{}{ + "other_key": "5s", + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 1 * time.Second, + }, + { + name: "Invalid duration string returns default", + config: map[string]interface{}{ + "test_key": "invalid duration", + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 1 * time.Second, + }, + { + name: "Non-string value returns default", + config: map[string]interface{}{ + "test_key": 123, + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 1 * time.Second, 
+ }, + { + name: "Nil value returns default", + config: map[string]interface{}{ + "test_key": nil, + }, + key: "test_key", + defaultValue: 1 * time.Second, + want: 1 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getDurationFromConfig(tt.config, tt.key, tt.defaultValue) + if got != tt.want { + t.Errorf("getDurationFromConfig() = %v, want %v", got, tt.want) + } + }) + } +}
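
For comparison with the plain pub/sub subscribers described in the config comments above (Subscribe "optimizely-sync-{sdk_key}"), a minimal external consumer of the new streams could look like the sketch below. It is an illustration only, not part of this change: it assumes the go-redis v8 client and the "stream:" prefix applied by getStreamName, so the notification stream name becomes "stream:optimizely-sync-<sdk_key>"; the Redis address, SDK key, and the group/consumer names are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/go-redis/redis/v8"
)

func main() {
	// Assumed values: adjust the address and SDK key for your deployment.
	const (
		addr     = "localhost:6379"
		sdkKey   = "<sdk_key>"
		group    = "external-consumers" // illustrative group name
		consumer = "consumer-1"         // illustrative consumer name
	)
	// RedisStreams.Publish prefixes the sync channel with "stream:", so the
	// notification stream is "stream:optimizely-sync-<sdk_key>".
	stream := fmt.Sprintf("stream:optimizely-sync-%s", sdkKey)

	ctx := context.Background()
	client := redis.NewClient(&redis.Options{Addr: addr})
	defer client.Close()

	// Create the consumer group at the end of the stream; tolerate the error
	// returned when the group already exists.
	if err := client.XGroupCreateMkStream(ctx, stream, group, "$").Err(); err != nil &&
		err.Error() != "BUSYGROUP Consumer Group name already exists" {
		panic(err)
	}

	for {
		// Block briefly waiting for new entries delivered to this consumer.
		streams, err := client.XReadGroup(ctx, &redis.XReadGroupArgs{
			Group:    group,
			Consumer: consumer,
			Streams:  []string{stream, ">"},
			Count:    10,
			Block:    5 * time.Second,
		}).Result()
		if err == redis.Nil {
			continue // no new messages within the block window
		} else if err != nil {
			panic(err)
		}

		for _, s := range streams {
			for _, msg := range s.Messages {
				// The publisher stores the notification payload under the "data" field.
				if data, ok := msg.Values["data"].(string); ok {
					fmt.Println(data)
				}
				// Acknowledge so the entry leaves this group's pending entries list.
				client.XAck(ctx, stream, group, msg.ID)
			}
		}
	}
}
```

Because each consumer group tracks its own pending entries, an external group like this one can read the same stream the agent replicas consume without stealing messages from them, which is the main behavioral difference from the fire-and-forget "redis" pub/sub mode.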