diff --git a/packages/client/lib/cluster/cluster-reconnection-tracker.spec.ts b/packages/client/lib/cluster/cluster-reconnection-tracker.spec.ts new file mode 100644 index 00000000000..5d38a51eb20 --- /dev/null +++ b/packages/client/lib/cluster/cluster-reconnection-tracker.spec.ts @@ -0,0 +1,156 @@ +import { strict as assert } from "node:assert"; +import ClusterReconnectionTracker from "./cluster-reconnection-tracker"; + +describe("ClusterReconnectionTracker", () => { + describe("validation", () => { + for (const strategy of [-1, 1.5, Number.NaN, true, null, "1000", {}]) { + it(`should throw when strategy is ${strategy}`, () => { + assert.throws( + () => new ClusterReconnectionTracker(strategy as never), + new TypeError( + "topologyRefreshOnReconnectionAttempt must be undefined, false, a non-negative integer, or a function", + ), + ); + }); + } + + it("should allow the default, false, 0, positive integer, and function strategies", () => { + assert.doesNotThrow(() => new ClusterReconnectionTracker()); + assert.doesNotThrow(() => new ClusterReconnectionTracker(false)); + assert.doesNotThrow(() => new ClusterReconnectionTracker(0)); + assert.doesNotThrow(() => new ClusterReconnectionTracker(1)); + assert.doesNotThrow( + () => new ClusterReconnectionTracker(() => undefined), + ); + }); + }); + + it("should not track anything when disabled", () => { + for (const strategy of [false, 0] as const) { + const state = new ClusterReconnectionTracker(strategy); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + false, + ); + assert.deepEqual([...state.reconnectingAddresses], []); + assert.equal(state.firstReconnectionAt, undefined); + } + }); + + it("should default to refreshing after five seconds", () => { + const state = new ClusterReconnectionTracker(); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + false, + ); + assert.equal(state.firstReconnectionAt, 100); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 5_099), + false, + ); + assert.equal(state.firstReconnectionAt, 100); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 5_100), + true, + ); + assert.equal(state.firstReconnectionAt, 5_100); + }); + + it("should track reconnecting clients by client id and remove them independently", () => { + const state = new ClusterReconnectionTracker(() => undefined); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + false, + ); + assert.deepEqual([...state.reconnectingAddresses], ["127.0.0.1:1"]); + assert.equal(state.firstReconnectionAt, 100); + + assert.equal( + state.onReconnectionAttempt("client-2", "127.0.0.1:2", 150), + false, + ); + assert.deepEqual([...state.reconnectingAddresses].sort(), [ + "127.0.0.1:1", + "127.0.0.1:2", + ]); + assert.equal(state.firstReconnectionAt, 100); + + state.removeClient("client-1"); + assert.deepEqual([...state.reconnectingAddresses], ["127.0.0.1:2"]); + assert.equal(state.firstReconnectionAt, 100); + + state.removeClient("client-2"); + assert.deepEqual([...state.reconnectingAddresses], []); + assert.equal(state.firstReconnectionAt, undefined); + }); + + it("should clear all reconnecting state", () => { + const state = new ClusterReconnectionTracker(() => undefined); + + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100); + state.onReconnectionAttempt("client-2", "127.0.0.1:2", 150); + state.clear(); + + assert.deepEqual([...state.reconnectingAddresses], []); + assert.equal(state.firstReconnectionAt, undefined); + 
}); + + it("should return true when enough time has elapsed and reset the timestamp", () => { + const state = new ClusterReconnectionTracker(50); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + false, + ); + assert.equal(state.firstReconnectionAt, 100); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 149), + false, + ); + assert.equal(state.firstReconnectionAt, 100); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 150), + true, + ); + assert.equal(state.firstReconnectionAt, 150); + }); + + it("should skip refresh when the function strategy returns false", () => { + const state = new ClusterReconnectionTracker(() => false); + + assert.equal( + state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + false, + ); + assert.deepEqual([...state.reconnectingAddresses], ["127.0.0.1:1"]); + assert.equal(state.firstReconnectionAt, 100); + }); + + it("should throw when the function strategy throws", () => { + const error = new Error("strategy failed"); + const state = new ClusterReconnectionTracker(() => { + throw error; + }); + + assert.throws( + () => state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + error, + ); + }); + + it("should throw when the function strategy returns an invalid value", () => { + const state = new ClusterReconnectionTracker(() => -1); + + assert.throws( + () => state.onReconnectionAttempt("client-1", "127.0.0.1:1", 100), + /topologyRefreshOnReconnectionAttempt should return/, + ); + }); +}); diff --git a/packages/client/lib/cluster/cluster-reconnection-tracker.ts b/packages/client/lib/cluster/cluster-reconnection-tracker.ts new file mode 100644 index 00000000000..462f07d9443 --- /dev/null +++ b/packages/client/lib/cluster/cluster-reconnection-tracker.ts @@ -0,0 +1,130 @@ +import type { ClusterTopologyRefreshOnReconnectionAttemptStrategy } from './index'; + +/** + * Tracks which cluster node clients are currently reconnecting and decides when + * to trigger a cluster topology refresh based on a configurable strategy. + * + * The strategy can be: + * - `undefined` - uses the default delay (5 seconds) + * - `false` or `0` - disables topology refresh on reconnection + * - a positive integer - delay in ms after the first reconnection attempt before refreshing + * - a function - custom logic receiving the timestamp of the first reconnection attempt, + * returning a delay or `false`/`undefined` to skip + * + * After the delay elapses, {@link onReconnectionAttempt} returns `true` once to signal + * that a refresh should be scheduled, then resets the timer. + */ +export default class ClusterReconnectionTracker { + /** Default delay (ms) before triggering a topology refresh after reconnection starts */ + static #DEFAULT_TOPOLOGY_REFRESH_ON_RECONNECTION_ATTEMPT = 5_000; + + readonly #strategy?: ClusterTopologyRefreshOnReconnectionAttemptStrategy; + /** Maps client ID to its node address for clients currently in a reconnecting state */ + readonly #reconnectingClients = new Map(); + /** Timestamp of the first reconnection attempt in the current reconnection cycle */ + #firstReconnectionAt?: number; + + /** + * Validates that a strategy value is acceptable before use. 
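+   * @param strategy The strategy value to check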
+ * @throws If the strategy is not supported + */ + #validate(strategy?: ClusterTopologyRefreshOnReconnectionAttemptStrategy) { + if ( + strategy === undefined || + strategy === false || + typeof strategy === 'function' || + ( + typeof strategy === 'number' && + Number.isInteger(strategy) && + strategy >= 0 + ) + ) { + return; + } + + throw new TypeError('topologyRefreshOnReconnectionAttempt must be undefined, false, a non-negative integer, or a function'); + } + + constructor(strategy?: ClusterTopologyRefreshOnReconnectionAttemptStrategy) { + this.#validate(strategy); + this.#strategy = strategy; + } + + get reconnectingAddresses() { + return new Set(this.#reconnectingClients.values()); + } + + get firstReconnectionAt() { + return this.#firstReconnectionAt; + } + + /** + * Records a reconnection attempt for the given client and evaluates whether + * the configured delay has elapsed since the first attempt in this cycle. + * + * @returns `true` if a topology refresh should be triggered, `false` otherwise + * @throws If a user-supplied strategy function returns an invalid value + */ + onReconnectionAttempt(clientId: string, address: string, now = Date.now()) { + if (this.#strategy === false || this.#strategy === 0) { + return false; + } + + this.#reconnectingClients.set(clientId, address); + this.#firstReconnectionAt ??= now; + + const delay = this.#getDelay(this.#firstReconnectionAt); + if (delay === undefined || now - this.#firstReconnectionAt < delay) { + return false; + } + + this.#firstReconnectionAt = now; + return true; + } + + /** Removes a client from tracking (e.g. when it reconnects successfully or disconnects) */ + removeClient(clientId: string) { + if (!this.#reconnectingClients.delete(clientId)) return; + + this.#clearTimestampIfClean(); + } + + /** Resets all tracking state (e.g. on cluster disconnect or destroy) */ + clear() { + this.#reconnectingClients.clear(); + this.#firstReconnectionAt = undefined; + } + + /** + * Evaluates the configured strategy to determine the delay before a topology refresh. 
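+   * A function strategy is invoked with the timestamp of the first reconnection attempt in the current cycle.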
+ * @returns The delay in ms, or `undefined` if no refresh should occur + */ + #getDelay(firstReconnectionAt: number) { + if (this.#strategy === undefined) { + return ClusterReconnectionTracker.#DEFAULT_TOPOLOGY_REFRESH_ON_RECONNECTION_ATTEMPT; + } + + if (this.#strategy === false) { + return; + } + + if (typeof this.#strategy === 'number') { + return this.#strategy; + } + + const delay = this.#strategy(firstReconnectionAt); + if (delay === false || delay === undefined || delay === 0) return; + + if (!Number.isInteger(delay) || delay < 0) { + throw new TypeError(`topologyRefreshOnReconnectionAttempt should return \`false | undefined | number\`, got ${delay} instead`); + } + + return delay; + } + + #clearTimestampIfClean() { + if (this.#reconnectingClients.size === 0) { + this.#firstReconnectionAt = undefined; + } + } +} diff --git a/packages/client/lib/cluster/cluster-slots.spec.ts b/packages/client/lib/cluster/cluster-slots.spec.ts index 86e4ecc06aa..8585b9a6780 100644 --- a/packages/client/lib/cluster/cluster-slots.spec.ts +++ b/packages/client/lib/cluster/cluster-slots.spec.ts @@ -6,7 +6,7 @@ import RedisClusterSlots from './cluster-slots'; describe('RedisClusterSlots', () => { describe('initialization', () => { describe('clientSideCache validation', () => { - const mockEmit = ((_event: string | symbol, ..._args: any[]): boolean => true) as EventEmitter['emit']; + const mockEmit: EventEmitter['emit'] = () => true; const clientSideCacheConfig = { ttl: 0, maxEntries: 0 }; const rootNodes: Array = [ { socket: { host: 'localhost', port: 30001 } } diff --git a/packages/client/lib/cluster/cluster-slots.ts b/packages/client/lib/cluster/cluster-slots.ts index c043dd7dc53..98a42420ae8 100644 --- a/packages/client/lib/cluster/cluster-slots.ts +++ b/packages/client/lib/cluster/cluster-slots.ts @@ -1,4 +1,4 @@ -import { RedisClusterClientOptions, RedisClusterOptions } from '.'; +import type { RedisClusterClientOptions, RedisClusterOptions } from '.'; import { RootNodesUnavailableError } from '../errors'; import RedisClient, { RedisClientOptions, RedisClientType } from '../client'; import { EventEmitter } from 'node:stream'; @@ -9,6 +9,7 @@ import { RedisSocketOptions } from '../client/socket'; import { BasicPooledClientSideCache, PooledClientSideCacheProvider } from '../client/cache'; import { SMIGRATED_EVENT, SMigratedEvent, dbgMaintenance } from '../client/enterprise-maintenance-manager'; import { ClientRole } from '../client/identity'; +import ClusterReconnectionTracker from './cluster-reconnection-tracker'; interface NodeAddress { host: string; @@ -112,6 +113,7 @@ export default class RedisClusterSlots< readonly #clientFactory; readonly #emit: EventEmitter['emit']; readonly #clusterClientId: string; + readonly #reconnectionTracker: ClusterReconnectionTracker; slots = new Array>(RedisClusterSlots.#SLOTS); masters = new Array>(); replicas = new Array>(); @@ -119,6 +121,7 @@ export default class RedisClusterSlots< pubSubNode?: PubSubNode; clientSideCache?: PooledClientSideCacheProvider; smigratedSeqIdsSeen = new Set; + #topologyRefreshPromise?: Promise; #isOpen = false; @@ -140,6 +143,7 @@ export default class RedisClusterSlots< this.#validateOptions(options); this.#options = options; this.#clusterClientId = clusterClientId; + this.#reconnectionTracker = new ClusterReconnectionTracker(options.topologyRefreshOnReconnectionAttemptStrategy); if (options?.clientSideCache) { if (options.clientSideCache instanceof PooledClientSideCacheProvider) { @@ -169,7 +173,7 @@ export default class 
RedisClusterSlots< } async #discoverWithRootNodes() { - let start = Math.floor(Math.random() * this.#options.rootNodes.length); + const start = Math.floor(Math.random() * this.#options.rootNodes.length); for (let i = start; i < this.#options.rootNodes.length; i++) { if (!this.#isOpen) throw new Error('Cluster closed'); if (await this.#discover(this.#options.rootNodes[i])) { @@ -225,6 +229,7 @@ export default class RedisClusterSlots< const channelsListeners = this.pubSubNode.client.getPubSubListeners(PUBSUB_TYPE.CHANNELS), patternsListeners = this.pubSubNode.client.getPubSubListeners(PUBSUB_TYPE.PATTERNS); + this.#reconnectionTracker.removeClient(this.pubSubNode.client._clientId); this.pubSubNode.client.destroy(); if (channelsListeners.size || patternsListeners.size) { @@ -241,13 +246,24 @@ export default class RedisClusterSlots< for (const [address, node] of this.nodeByAddress.entries()) { if (addressesInUse.has(address)) continue; + const { pubSub } = node as MasterNode; + if (pubSub) { + const listeners = pubSub.client._getQueue().removeAllPubSubListeners(); + if (listeners.CHANNELS.size || listeners.PATTERNS.size || listeners.SHARDED.size) { + this.#emit(RESUBSCRIBE_LISTENERS_EVENT, listeners); + } + } + if (node.client) { + this.#reconnectionTracker.removeClient(node.client._clientId); node.client.destroy(); + node.client = undefined; } - const { pubSub } = node as MasterNode; if (pubSub) { + this.#reconnectionTracker.removeClient(pubSub.client._clientId); pubSub.client.destroy(); + (node as MasterNode).pubSub = undefined; } this.nodeByAddress.delete(address); @@ -431,17 +447,20 @@ export default class RedisClusterSlots< this.pubSubNode = undefined; } + this.#reconnectionTracker.removeClient(oldPubSubClient._clientId); oldPubSubClient.destroy(); } // Destroy source connections (use destroy() instead of close() since the node is being removed // and close() can hang if the server is not responding) - sourceNode.client?.destroy(); - if ('pubSub' in sourceNode) { - sourceNode.pubSub?.client.destroy(); + this.#reconnectionTracker.removeClient(sourceNode.client?._clientId); + sourceNode.client.destroy(); + if ('pubSub' in sourceNode && sourceNode.pubSub) { + this.#reconnectionTracker.removeClient(sourceNode.pubSub.client._clientId); + sourceNode.pubSub.client.destroy(); } } - } catch (err: any) { + } catch (err: unknown) { dbgMaintenance(`[CSlots]: Error during SMIGRATED handling for source ${sourceAddress}: ${err}`); // Ensure we unpause source on error to prevent deadlock sourceNode.client?._unpause(); @@ -484,6 +503,15 @@ export default class RedisClusterSlots< } } + #nodeClientOptions(node: NodeAddress & { address: string }): RedisClusterClientOptions { + return { + socket: this.#getNodeAddress(node.address) ?? 
{ + host: node.host, + port: node.port + } + }; + } + #clientOptionsDefaults(options?: RedisClientOptions) { if (!this.#options.defaults) return options; @@ -546,7 +574,9 @@ export default class RedisClusterSlots< host: socket.host, port: socket.port, }); + const address = node.address; const emit = this.#emit; + let wasReady = false; const client = this.#clientFactory( this.#clientOptionsDefaults({ clientSideCache: this.clientSideCache, RESP: this.#options.RESP, @@ -556,10 +586,23 @@ export default class RedisClusterSlots< client._setIdentity(ClientRole.CLUSTER_NODE, this.#clusterClientId); client .on('error', error => emit('node-error', error, clientInfo)) - .on('reconnecting', () => emit('node-reconnecting', clientInfo)) + .on('reconnecting', () => { + emit('node-reconnecting', clientInfo); + + if (!wasReady) return; + + this.#onNodeReconnectionAttempt(client._clientId, address); + }) + .on('ready', () => { + wasReady = true; + this.#reconnectionTracker.removeClient(client._clientId); + }) .once('ready', () => emit('node-ready', clientInfo)) .once('connect', () => emit('node-connect', clientInfo)) - .once('end', () => emit('node-disconnect', clientInfo)) + .once('end', () => { + this.#reconnectionTracker.removeClient(client._clientId); + emit('node-disconnect', clientInfo); + }) .on(SMIGRATED_EVENT, this.#handleSmigrated) .on('__MOVED', async (allPubSubListeners: PubSubListeners) => { await this.rediscover(client); @@ -588,20 +631,92 @@ export default class RedisClusterSlots< #runningRediscoverPromise?: Promise; - async rediscover(startWith: RedisClientType): Promise { - this.#runningRediscoverPromise ??= this.#rediscover(startWith) + async rediscover(startWith?: RedisClientType, excludedAddresses?: ReadonlySet): Promise { + this.#runningRediscoverPromise ??= this.#rediscover(startWith, excludedAddresses) .finally(() => { this.#runningRediscoverPromise = undefined }); return this.#runningRediscoverPromise; } - async #rediscover(startWith: RedisClientType): Promise { - if (await this.#discover(startWith.options!)) return; + async #rediscover(startWith?: RedisClientType, excludedAddresses?: ReadonlySet): Promise { + if (startWith && await this.#discover(startWith.options!)) return; + + if (await this.#discoverWithKnownNodes(excludedAddresses)) return; return this.#discoverWithRootNodes(); } + async #discoverWithKnownNodes(excludedAddresses?: ReadonlySet) { + const candidates: Array> = []; + const deferredCandidates: Array> = []; + const seen = new Set(); + + for (const nodes of [this.masters, this.replicas]) { + for (const node of nodes) { + if (excludedAddresses?.has(node.address) || seen.has(node.address)) continue; + + seen.add(node.address); + + if (node.client?.isReady) { + candidates.push(node); + } else { + deferredCandidates.push(node); + } + } + } + + return ( + await this.#discoverWithKnownNodeCandidates(candidates) || + await this.#discoverWithKnownNodeCandidates(deferredCandidates) + ); + } + + async #discoverWithKnownNodeCandidates(candidates: Array>) { + if (!candidates.length) { + return false; + } + + const start = Math.floor(Math.random() * candidates.length); + for (let i = 0; i < candidates.length; i++) { + + if (!this.#isOpen) { + continue; + } + + const candidate = candidates[(start + i) % candidates.length]; + if (await this.#discover(this.#nodeClientOptions(candidate))) { + return true; + } + } + + return false; + } + + #onNodeReconnectionAttempt(clientId: string, address: string) { + let shouldRefresh: boolean; + try { + shouldRefresh = 
this.#reconnectionTracker.onReconnectionAttempt(clientId, address); + } catch (err) { + this.#emit('error', err); + return; + } + + if (shouldRefresh) { + this.#scheduleTopologyRefresh(this.#reconnectionTracker.reconnectingAddresses); + } + } + + #scheduleTopologyRefresh(excludedAddresses: ReadonlySet) { + if (!this.#isOpen || this.#topologyRefreshPromise) return; + + this.#topologyRefreshPromise = this.rediscover(undefined, new Set(excludedAddresses)) + .catch(err => this.#emit('error', err)) + .finally(() => { + this.#topologyRefreshPromise = undefined; + }); + } + /** * @deprecated Use `close` instead. */ @@ -634,6 +749,7 @@ export default class RedisClusterSlots< this.#resetSlots(); this.nodeByAddress.clear(); + this.#reconnectionTracker.clear(); this.#emit('disconnect'); } @@ -670,6 +786,7 @@ export default class RedisClusterSlots< this.#resetSlots(); this.nodeByAddress.clear(); + this.#reconnectionTracker.clear(); await Promise.allSettled(promises); this.#emit('disconnect'); @@ -816,6 +933,7 @@ export default class RedisClusterSlots< await unsubscribe(client); if (!client.isPubSubActive) { + this.#reconnectionTracker.removeClient(client._clientId); client.destroy(); this.pubSubNode = undefined; } @@ -873,6 +991,7 @@ export default class RedisClusterSlots< await unsubscribe(client); if (!client.isPubSubActive) { + this.#reconnectionTracker.removeClient(client._clientId); client.destroy(); master.pubSub = undefined; } diff --git a/packages/client/lib/cluster/index.ts b/packages/client/lib/cluster/index.ts index fbdebad16a7..1c994d03314 100644 --- a/packages/client/lib/cluster/index.ts +++ b/packages/client/lib/cluster/index.ts @@ -16,6 +16,11 @@ import SingleEntryCache from '../single-entry-cache' import { publish, CHANNELS } from '../client/tracing'; import { ClientIdentity, ClientRole, generateClusterClientId } from '../client/identity'; +export type ClusterTopologyRefreshOnReconnectionAttemptStrategy = + false | + number | + ((firstReconnectionAt: number) => false | number | undefined); + type WithCommands< RESP extends RespVersions, TYPE_MAPPING extends TypeMapping @@ -72,6 +77,15 @@ export interface RedisClusterOptions< * The maximum number of times a command will be redirected due to `MOVED` or `ASK` errors. */ maxCommandRedirections?: number; + /** + * The number of milliseconds after the first post-ready node reconnection attempt + * before background cluster topology refreshes are triggered. Omitted or `undefined` + * uses the default delay of `5000`. + * Use `false` or `0` to disable reconnect-triggered topology refreshes. A function can + * return the delay dynamically, or `false`/`undefined`/`0` to skip the refresh attempt. + * Concurrent refreshes are de-duplicated. 
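+   *
+   * @example A minimal sketch of enabling the option on a cluster client (host/port are illustrative)
+   * createCluster({
+   *   rootNodes: [{ socket: { host: '127.0.0.1', port: 30001 } }],
+   *   // refresh topology 10 seconds after the first post-ready reconnection attempt
+   *   topologyRefreshOnReconnectionAttemptStrategy: 10_000
+   * });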
+ */ + topologyRefreshOnReconnectionAttemptStrategy?: ClusterTopologyRefreshOnReconnectionAttemptStrategy; /** * Mapping between the addresses in the cluster (see `CLUSTER SHARDS`) and the addresses the client should connect to * Useful when the cluster is running on another network @@ -129,14 +143,12 @@ export type RedisClusterType< WithScripts ); -export interface ClusterCommandOptions< +export type ClusterCommandOptions< TYPE_MAPPING extends TypeMapping = TypeMapping // POLICIES extends CommandPolicies = CommandPolicies -> extends CommandOptions { - // policies?: POLICIES; -} +> = CommandOptions; -type ProxyCluster = RedisCluster; +type ProxyCluster = RedisCluster; type NamespaceProxyCluster = { _self: ProxyCluster }; @@ -216,6 +228,7 @@ export default class RedisCluster< }; } + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- cache stores dynamically generated cluster subclasses static #SingleEntryCache = new SingleEntryCache(); static factory< @@ -535,7 +548,7 @@ export default class RedisCluster< MULTI(routing?: RedisArgument) { type Multi = new (...args: ConstructorParameters) => RedisClusterMultiCommandType<[], M, F, S, RESP, TYPE_MAPPING>; - return new ((this as any).Multi as Multi)( + return new (this as this & { Multi: Multi }).Multi( async (firstKey, isReadonly, commands) => { const { client } = await this._self._slots.getClientAndSlotNumber(firstKey, isReadonly); return client._executeMulti(commands); diff --git a/packages/client/lib/tests/test-scenario/fault-injector-client.ts b/packages/client/lib/tests/test-scenario/fault-injector-client.ts index 4fd35fa446c..6647f6497a9 100644 --- a/packages/client/lib/tests/test-scenario/fault-injector-client.ts +++ b/packages/client/lib/tests/test-scenario/fault-injector-client.ts @@ -11,7 +11,10 @@ export type ActionType = | "execute_rladmin_command" | "migrate" | "bind" - | "update_cluster_config"; + | "update_cluster_config" + | "node_failure" + | "proxy_failure" + | "shard_failure"; export interface ActionRequest { type: ActionType; @@ -143,7 +146,7 @@ export class FaultInjectorClient { async #request( method: string, path: string, - body?: Object | string + body?: object | string ): Promise { const url = `${this.baseUrl}${path}`; const headers: Record = { diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts index 46ef252da8e..88d8bf6ae65 100644 --- a/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/spubsub.e2e.ts @@ -1,362 +1,582 @@ -import type { Cluster, TestConfig } from "./utils/test.util"; -import { createClusterTestClient, getConfig } from "./utils/test.util"; -import { FaultInjectorClient } from "../fault-injector-client"; +import type { ActionRequest } from "@redis/test-utils/lib/fault-injector"; +import testUtils from "../../../test-utils"; import { TestCommandRunner } from "./utils/command-runner"; import { CHANNELS, CHANNELS_BY_SLOT } from "./utils/test.util"; import { MessageTracker } from "./utils/message-tracker"; import assert from "node:assert"; import { setTimeout } from "node:timers/promises"; -describe("Sharded Pub/Sub E2E", () => { - let faultInjectorClient: FaultInjectorClient; - let config: TestConfig; - - before(() => { - config = getConfig(); - - faultInjectorClient = new FaultInjectorClient(config.faultInjectorUrl); - }); - - describe("Single Subscriber", () => { - let subscriber: Cluster; - let publisher: 
Cluster; - let messageTracker: MessageTracker; - - beforeEach(async () => { - messageTracker = new MessageTracker(CHANNELS); - subscriber = createClusterTestClient(config.clientConfig, {}); - publisher = createClusterTestClient(config.clientConfig, {}); - await Promise.all([subscriber.connect(), publisher.connect()]); - }); - - afterEach(async () => { - await Promise.all([subscriber.quit(), publisher.quit()]); - }); - - it("should receive messages published to multiple channels", async () => { - for (const channel of CHANNELS) { - await subscriber.sSubscribe(channel, (_msg, channel) => - messageTracker.incrementReceived(channel), - ); - } - const { controller, result } = - TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker, - ); - // Wait for 10 seconds, while publishing messages - await setTimeout(10_000); - controller.abort(); +const TEST_TIMEOUT = 180_000; +const CLUSTER_INDEX = 0; +const FAILURE_ACTION_TIMEOUT_MS = 120_000; +const FAILURE_PUBLISH_WARMUP_MS = 1_000; +const FAILURE_RECOVERY_WAIT_MS = 2_000; +const PUBLISH_VERIFICATION_DURATION_MS = 10_000; +const UNSUBSCRIBE_VERIFICATION_DURATION_MS = 5_000; +const POST_RECOVERY_PUBLISH_DURATION_MS = 10_000; +const POST_RECOVERY_DELIVERY_RATIO = 0.9; +const RECOVERY_ASSERTION_TIMEOUT_MS = 45_000; +const RECOVERY_ASSERTION_INTERVAL_MS = 100; + +const FAILURE_CASES = [ + { + name: "should resume publishing and receiving after failover", + action: { + type: "failover", + parameters: { + cluster_index: CLUSTER_INDEX, + }, + }, + }, + { + name: "should resume publishing and receiving after rebooting a cluster node", + action: { + type: "node_failure", + parameters: { + cluster_index: CLUSTER_INDEX, + node_id: 1, + method: "reboot", + }, + }, + }, + { + name: "should resume publishing and receiving after restarting the database proxy", + action: { + type: "proxy_failure", + parameters: { + cluster_index: CLUSTER_INDEX, + action: "restart", + }, + }, + }, + { + name: "should resume publishing and receiving after a shard failure", + action: { + type: "shard_failure", + parameters: { + cluster_index: CLUSTER_INDEX, + }, + }, + }, +] as const satisfies ReadonlyArray<{ + name: string; + action: Readonly; +}>; + +function getChannelStatsOrThrow(messageTracker: MessageTracker, channel: string) { + const stats = messageTracker.getChannelStats(channel); + assert.ok(stats, `Expected stats for channel ${channel}`); + return stats; +} + +type BackgroundPublisher = Parameters< + typeof TestCommandRunner.publishMessagesUntilAbortSignal +>[0]; +type BackgroundPublishOptions = Parameters< + typeof TestCommandRunner.publishMessagesUntilAbortSignal +>[3]; + +async function createConnectedDuplicate< + T extends { + duplicate(): T; + connect(): Promise; + }, +>(client: T): Promise { + const duplicate = client.duplicate(); + await duplicate.connect(); + return duplicate; +} + +async function withBackgroundPublishing( + client: BackgroundPublisher, + channels: string[], + messageTracker: MessageTracker, + callback: () => Promise, + options?: BackgroundPublishOptions, +): Promise { + const { controller, result } = + TestCommandRunner.publishMessagesUntilAbortSignal( + client, + channels, + messageTracker, + options, + ); + + let callbackError: unknown; + + try { + return await callback(); + } catch (error) { + callbackError = error; + throw error; + } finally { + controller.abort(); + + try { await result; - - for (const channel of CHANNELS) { - assert.strictEqual( - messageTracker.getChannelStats(channel)?.received, - 
messageTracker.getChannelStats(channel)?.sent, - ); - } - }); - - it("should resume publishing and receiving after failover", async () => { - for (const channel of CHANNELS) { - await subscriber.sSubscribe(channel, (_msg, channel) => { - messageTracker.incrementReceived(channel); - }); + } catch (error) { + if (callbackError === undefined) { + // eslint-disable-next-line no-unsafe-finally -- intentionally surface publisher failure when callback succeeded + throw error; } + } + } +} + +async function waitForAssertion( + assertion: () => void, + timeoutMs: number, + intervalMs = RECOVERY_ASSERTION_INTERVAL_MS, +) { + const start = Date.now(); + let lastError: unknown; + + while (Date.now() - start < timeoutMs) { + try { + assertion(); + return; + } catch (error) { + lastError = error; + await setTimeout(intervalMs); + } + } + + throw lastError instanceof Error + ? lastError + : new Error(`Assertion did not pass within ${timeoutMs}ms`); +} - // Trigger failover twice - for (let i = 0; i < 2; i++) { - // Start publishing messages - const { controller: publishAbort, result: publishResult } = - TestCommandRunner.publishMessagesUntilAbortSignal( +describe("Sharded Pub/Sub E2E", () => { + describe("Single Subscriber", () => { + testUtils.testWithRECluster( + "should receive messages published to multiple channels", + async (cluster) => { + const messageTracker = new MessageTracker(CHANNELS); + const publisher = cluster; + const subscriber = await createConnectedDuplicate(cluster); + + try { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_message, receivedChannel) => + messageTracker.incrementReceived(receivedChannel), + ); + } + + await withBackgroundPublishing( publisher, CHANNELS, messageTracker, - ); - - // Trigger failover during publishing - const { action_id: failoverActionId } = - await faultInjectorClient.triggerAction({ - type: "failover", - parameters: { - bdb_id: config.clientConfig.bdbId.toString(), - cluster_index: 0, + async () => { + await setTimeout(PUBLISH_VERIFICATION_DURATION_MS); }, - }); - - // Wait for failover to complete - await faultInjectorClient.waitForAction(failoverActionId); - - publishAbort.abort(); - await publishResult; - - for (const channel of CHANNELS) { - const sent = messageTracker.getChannelStats(channel)!.sent; - const received = messageTracker.getChannelStats(channel)!.received; - - assert.ok( - received <= sent, - `Channel ${channel}: received (${received}) should be <= sent (${sent})`, ); - } - // Wait for 2 seconds before resuming publishing - await setTimeout(2_000); - messageTracker.reset(); - - const { - controller: afterFailoverController, - result: afterFailoverResult, - } = TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker, - ); - - await setTimeout(10_000); - afterFailoverController.abort(); - await afterFailoverResult; - - for (const channel of CHANNELS) { - const sent = messageTracker.getChannelStats(channel)!.sent; - const received = messageTracker.getChannelStats(channel)!.received; - assert.ok(sent > 0, `Channel ${channel} should have sent messages`); - assert.ok( - received > 0, - `Channel ${channel} should have received messages`, - ); - assert.strictEqual( - messageTracker.getChannelStats(channel)!.received, - messageTracker.getChannelStats(channel)!.sent, - `Channel ${channel} received (${received}) should equal sent (${sent}) once resumed after failover`, - ); + for (const channel of CHANNELS) { + const { sent, received } = getChannelStatsOrThrow( + messageTracker, + 
channel, + ); + + assert.strictEqual( + received, + sent, + `Channel ${channel} should receive every published message`, + ); + } + } finally { + subscriber.destroy(); } - } - }); - - it("should NOT receive messages after sunsubscribe", async () => { - for (const channel of CHANNELS) { - await subscriber.sSubscribe(channel, (_msg, channel) => messageTracker.incrementReceived(channel)); - } - - const { controller, result } = - TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker, - ); - - // Wait for 5 seconds, while publishing messages - await setTimeout(5_000); - controller.abort(); - await result; + }, + { testTimeout: TEST_TIMEOUT }, + ); + + for (const failureCase of FAILURE_CASES) { + testUtils.testWithRECluster( + failureCase.name, + async (cluster, faultInjectorClient) => { + const messageTracker = new MessageTracker(CHANNELS); + const publisher = cluster; + const subscriber = await createConnectedDuplicate(cluster); + + try { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_message, receivedChannel) => { + messageTracker.incrementReceived(receivedChannel); + }); + } + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker, + async () => { + await setTimeout(FAILURE_PUBLISH_WARMUP_MS); + + await faultInjectorClient.triggerAction(failureCase.action, { + maxWaitTimeMs: FAILURE_ACTION_TIMEOUT_MS, + }); + }, + ); + + const sentDuringFailure = CHANNELS.reduce( + (sum, channel) => + sum + getChannelStatsOrThrow(messageTracker, channel).sent, + 0, + ); + const receivedDuringFailure = CHANNELS.reduce( + (sum, channel) => + sum + getChannelStatsOrThrow(messageTracker, channel).received, + 0, + ); + + assert.ok( + sentDuringFailure > 0, + "Expected messages to be published during the failure scenario", + ); + assert.ok( + receivedDuringFailure > 0, + "Expected messages to be received during the failure scenario", + ); + + for (const channel of CHANNELS) { + const { sent, received } = getChannelStatsOrThrow( + messageTracker, + channel, + ); + + assert.ok( + received <= sent, + `Channel ${channel}: received (${received}) should be <= sent (${sent})`, + ); + assert.ok( + received > 0, + `Channel ${channel} should receive messages during the failure scenario`, + ); + } + + await setTimeout(FAILURE_RECOVERY_WAIT_MS); + messageTracker.reset(); + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker, + async () => { + await waitForAssertion(() => { + for (const channel of CHANNELS) { + const { received } = getChannelStatsOrThrow( + messageTracker, + channel, + ); + + assert.ok( + received > 0, + `Channel ${channel} should resume receiving messages after recovery`, + ); + } + }, RECOVERY_ASSERTION_TIMEOUT_MS); + }, + ); + + messageTracker.reset(); + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker, + async () => { + await setTimeout(POST_RECOVERY_PUBLISH_DURATION_MS); + }, + ); + + for (const channel of CHANNELS) { + const { sent, received } = getChannelStatsOrThrow( + messageTracker, + channel, + ); + const deliveryRatio = received / sent; + + assert.ok( + sent > 0, + `Channel ${channel} should have sent messages`, + ); + assert.ok( + received > 0, + `Channel ${channel} should have received messages`, + ); + assert.ok( + deliveryRatio >= POST_RECOVERY_DELIVERY_RATIO, + `Channel ${channel} received ${received} of ${sent} messages after recovery (${( + deliveryRatio * 100 + ).toFixed(1)}%)`, + ); + } + } finally { + subscriber.destroy(); + } + }, + { testTimeout: 
TEST_TIMEOUT }, + ); + } + + testUtils.testWithRECluster( + "should NOT receive messages after sunsubscribe", + async (cluster) => { + const messageTracker = new MessageTracker(CHANNELS); + const publisher = cluster; + const subscriber = await createConnectedDuplicate(cluster); + + try { + for (const channel of CHANNELS) { + await subscriber.sSubscribe(channel, (_message, receivedChannel) => + messageTracker.incrementReceived(receivedChannel), + ); + } + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker, + async () => { + await setTimeout(UNSUBSCRIBE_VERIFICATION_DURATION_MS); + }, + ); - for (const channel of CHANNELS) { - assert.strictEqual( - messageTracker.getChannelStats(channel)?.received, - messageTracker.getChannelStats(channel)?.sent, - ); - } + for (const channel of CHANNELS) { + const { sent, received } = getChannelStatsOrThrow( + messageTracker, + channel, + ); - // Reset message tracker - messageTracker.reset(); + assert.strictEqual( + received, + sent, + `Channel ${channel} should receive every published message before unsubscribe`, + ); + } - const unsubscribeChannels = [ - CHANNELS_BY_SLOT["1000"], - CHANNELS_BY_SLOT["8000"], - CHANNELS_BY_SLOT["16000"], - ]; + messageTracker.reset(); - for (const channel of unsubscribeChannels) { - await subscriber.sUnsubscribe(channel); - } + const unsubscribeChannels = [ + CHANNELS_BY_SLOT["1000"], + CHANNELS_BY_SLOT["8000"], + CHANNELS_BY_SLOT["16000"], + ]; - const { - controller: afterUnsubscribeController, - result: afterUnsubscribeResult, - } = TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker, - ); + for (const channel of unsubscribeChannels) { + await subscriber.sUnsubscribe(channel); + } - // Wait for 5 seconds, while publishing messages - await setTimeout(5_000); - afterUnsubscribeController.abort(); - await afterUnsubscribeResult; - - for (const channel of unsubscribeChannels) { - assert.strictEqual( - messageTracker.getChannelStats(channel)?.received, - 0, - `Channel ${channel} should not have received messages after unsubscribe`, - ); - } + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker, + async () => { + await setTimeout(UNSUBSCRIBE_VERIFICATION_DURATION_MS); + }, + ); - // All other channels should have received messages - const stillSubscribedChannels = CHANNELS.filter( - (channel) => !unsubscribeChannels.includes(channel as any), - ); + for (const channel of unsubscribeChannels) { + assert.strictEqual( + getChannelStatsOrThrow(messageTracker, channel).received, + 0, + `Channel ${channel} should not receive messages after unsubscribe`, + ); + } + + const unsubscribedChannels = new Set(unsubscribeChannels); + const stillSubscribedChannels = CHANNELS.filter( + (channel) => !unsubscribedChannels.has(channel), + ); - for (const channel of stillSubscribedChannels) { - assert.ok( - messageTracker.getChannelStats(channel)!.received > 0, - `Channel ${channel} should have received messages`, - ); - } - }); + for (const channel of stillSubscribedChannels) { + assert.ok( + getChannelStatsOrThrow(messageTracker, channel).received > 0, + `Channel ${channel} should continue receiving messages`, + ); + } + } finally { + subscriber.destroy(); + } + }, + { testTimeout: TEST_TIMEOUT }, + ); }); describe("Multiple Subscribers", () => { - let subscriber1: Cluster; - let subscriber2: Cluster; - - let publisher: Cluster; - - let messageTracker1: MessageTracker; - let messageTracker2: MessageTracker; - - beforeEach(async () => { - messageTracker1 = 
new MessageTracker(CHANNELS); - messageTracker2 = new MessageTracker(CHANNELS); - subscriber1 = createClusterTestClient(config.clientConfig); - subscriber2 = createClusterTestClient(config.clientConfig); - publisher = createClusterTestClient(config.clientConfig); - await Promise.all([ - subscriber1.connect(), - subscriber2.connect(), - publisher.connect(), - ]); - }); - - afterEach(async () => { - await Promise.all([ - subscriber1.quit(), - subscriber2.quit(), - publisher.quit(), - ]); - }); - - it("should receive messages published to multiple channels", async () => { - for (const channel of CHANNELS) { - await subscriber1.sSubscribe(channel, (_msg, channel) => { messageTracker1.incrementReceived(channel); }); - await subscriber2.sSubscribe(channel, (_msg, channel) => { messageTracker2.incrementReceived(channel); }); - } - - const { controller, result } = - TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker1, // Use messageTracker1 for all publishing - ); - - // Wait for 10 seconds, while publishing messages - await setTimeout(10_000); - controller.abort(); - await result; - - for (const channel of CHANNELS) { - assert.strictEqual( - messageTracker1.getChannelStats(channel)?.received, - messageTracker1.getChannelStats(channel)?.sent, - ); - assert.strictEqual( - messageTracker2.getChannelStats(channel)?.received, - messageTracker1.getChannelStats(channel)?.sent, - ); - } - }); - - it("should resume publishing and receiving after failover", async () => { - for (const channel of CHANNELS) { - await subscriber1.sSubscribe(channel, (_msg, channel) => { messageTracker1.incrementReceived(channel); }); - await subscriber2.sSubscribe(channel, (_msg, channel) => { messageTracker2.incrementReceived(channel); }); - } - - // Start publishing messages - const { controller: publishAbort, result: publishResult } = - TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker1, // Use messageTracker1 for all publishing - ); - - // Trigger failover during publishing - const { action_id: failoverActionId } = - await faultInjectorClient.triggerAction({ - type: "failover", - parameters: { - bdb_id: config.clientConfig.bdbId.toString(), - cluster_index: 0, - }, - }); - - // Wait for failover to complete - await faultInjectorClient.waitForAction(failoverActionId); - - publishAbort.abort(); - await publishResult; - - for (const channel of CHANNELS) { - const sent = messageTracker1.getChannelStats(channel)!.sent; - const received1 = messageTracker1.getChannelStats(channel)!.received; - - const received2 = messageTracker2.getChannelStats(channel)!.received; - - assert.ok( - received1 <= sent, - `Channel ${channel}: received (${received1}) should be <= sent (${sent})`, - ); - assert.ok( - received2 <= sent, - `Channel ${channel}: received2 (${received2}) should be <= sent (${sent})`, - ); - } - - // Wait for 2 seconds before resuming publishing - await setTimeout(2_000); + testUtils.testWithRECluster( + "should receive messages published to multiple channels", + async (cluster) => { + const messageTracker1 = new MessageTracker(CHANNELS); + const messageTracker2 = new MessageTracker(CHANNELS); + const publisher = cluster; + const [subscriber1, subscriber2] = await Promise.all([ + createConnectedDuplicate(cluster), + createConnectedDuplicate(cluster), + ]); + + try { + for (const channel of CHANNELS) { + await subscriber1.sSubscribe(channel, (_message, receivedChannel) => { + messageTracker1.incrementReceived(receivedChannel); + }); + await 
subscriber2.sSubscribe(channel, (_message, receivedChannel) => { + messageTracker2.incrementReceived(receivedChannel); + }); + } + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker1, + async () => { + await setTimeout(PUBLISH_VERIFICATION_DURATION_MS); + }, + ); - messageTracker1.reset(); - messageTracker2.reset(); + for (const channel of CHANNELS) { + const { sent, received: received1 } = getChannelStatsOrThrow( + messageTracker1, + channel, + ); + const { received: received2 } = getChannelStatsOrThrow( + messageTracker2, + channel, + ); + + assert.strictEqual( + received1, + sent, + `Channel ${channel} should deliver every message to subscriber 1`, + ); + assert.strictEqual( + received2, + sent, + `Channel ${channel} should deliver every message to subscriber 2`, + ); + } + } finally { + subscriber1.destroy(); + subscriber2.destroy(); + } + }, + { testTimeout: TEST_TIMEOUT }, + ); + + testUtils.testWithRECluster( + "should resume publishing and receiving after failover", + async (cluster, faultInjectorClient) => { + const messageTracker1 = new MessageTracker(CHANNELS); + const messageTracker2 = new MessageTracker(CHANNELS); + const publisher = cluster; + const [subscriber1, subscriber2] = await Promise.all([ + createConnectedDuplicate(cluster), + createConnectedDuplicate(cluster), + ]); + + try { + for (const channel of CHANNELS) { + await subscriber1.sSubscribe(channel, (_message, receivedChannel) => { + messageTracker1.incrementReceived(receivedChannel); + }); + await subscriber2.sSubscribe(channel, (_message, receivedChannel) => { + messageTracker2.incrementReceived(receivedChannel); + }); + } + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker1, + async () => { + await setTimeout(FAILURE_PUBLISH_WARMUP_MS); + + await faultInjectorClient.triggerAction( + { + type: "failover", + parameters: { + cluster_index: CLUSTER_INDEX, + }, + }, + { + maxWaitTimeMs: FAILURE_ACTION_TIMEOUT_MS, + }, + ); + }, + ); - const { - controller: afterFailoverController, - result: afterFailoverResult, - } = TestCommandRunner.publishMessagesUntilAbortSignal( - publisher, - CHANNELS, - messageTracker1, - ); + for (const channel of CHANNELS) { + const sent = getChannelStatsOrThrow(messageTracker1, channel).sent; + const received1 = getChannelStatsOrThrow( + messageTracker1, + channel, + ).received; + const received2 = getChannelStatsOrThrow( + messageTracker2, + channel, + ).received; + + assert.ok( + received1 <= sent, + `Channel ${channel}: subscriber 1 received (${received1}) should be <= sent (${sent})`, + ); + assert.ok( + received2 <= sent, + `Channel ${channel}: subscriber 2 received (${received2}) should be <= sent (${sent})`, + ); + } + + await setTimeout(FAILURE_RECOVERY_WAIT_MS); + + messageTracker1.reset(); + messageTracker2.reset(); + + await withBackgroundPublishing( + publisher, + CHANNELS, + messageTracker1, + async () => { + await setTimeout(PUBLISH_VERIFICATION_DURATION_MS); + }, + ); - await setTimeout(10_000); - afterFailoverController.abort(); - await afterFailoverResult; - - for (const channel of CHANNELS) { - const sent = messageTracker1.getChannelStats(channel)!.sent; - const received1 = messageTracker1.getChannelStats(channel)!.received; - const received2 = messageTracker2.getChannelStats(channel)!.received; - assert.ok(sent > 0, `Channel ${channel} should have sent messages`); - assert.ok( - received1 > 0, - `Channel ${channel} should have received messages by subscriber 1`, - ); - assert.ok( - received2 > 0, - `Channel ${channel} 
should have received messages by subscriber 2`, - ); - assert.strictEqual( - received1, - sent, - `Channel ${channel} received (${received1}) should equal sent (${sent}) once resumed after failover by subscriber 1`, - ); - assert.strictEqual( - received2, - sent, - `Channel ${channel} received (${received2}) should equal sent (${sent}) once resumed after failover by subscriber 2`, - ); - } - }); + for (const channel of CHANNELS) { + const sent = getChannelStatsOrThrow(messageTracker1, channel).sent; + const received1 = getChannelStatsOrThrow( + messageTracker1, + channel, + ).received; + const received2 = getChannelStatsOrThrow( + messageTracker2, + channel, + ).received; + + assert.ok(sent > 0, `Channel ${channel} should have sent messages`); + assert.ok( + received1 > 0, + `Channel ${channel} should have received messages by subscriber 1`, + ); + assert.ok( + received2 > 0, + `Channel ${channel} should have received messages by subscriber 2`, + ); + assert.strictEqual( + received1, + sent, + `Channel ${channel} should fully recover for subscriber 1 after failover`, + ); + assert.strictEqual( + received2, + sent, + `Channel ${channel} should fully recover for subscriber 2 after failover`, + ); + } + } finally { + subscriber1.destroy(); + subscriber2.destroy(); + } + }, + { testTimeout: TEST_TIMEOUT }, + ); }); }); diff --git a/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts index 7b1a217bbfd..b7322b2c1b6 100644 --- a/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts +++ b/packages/client/lib/tests/test-scenario/sharded-pubsub/utils/command-runner.ts @@ -1,7 +1,10 @@ import type { MessageTracker } from "./message-tracker"; -import { Cluster } from "./test.util"; import { setTimeout } from "timers/promises"; +interface ShardedPublisher { + sPublish(channel: string, message: string): Promise; +} + /** * Options for the `publishMessagesUntilAbortSignal` method */ @@ -43,7 +46,7 @@ export class TestCommandRunner { * An object containing the abort controller and a promise that resolves when publishing stops. */ static publishMessagesUntilAbortSignal( - client: Cluster, + client: ShardedPublisher, channels: string[], messageTracker: MessageTracker, options?: Partial,