diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index 40415f1b2c4..e8542e647e9 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- **BREAKING:** Modify DeFi position fetching behaviour ([#6944](https://github.com/MetaMask/core/pull/6944)) + - The fetch request to the API times out after 8 seconds and attempts a single retry + - Refresh only updates the selected evm address + - `KeyringController:unlock` no longer starts polling + - `AccountsController:accountAdded` no longer updates DeFi positions + - `AccountTreeController:selectedAccountGroupChange` updates DeFi positions for the selected address + - `TransactionController:transactionConfirmed` only updates DeFi positions if the transaction is for the selected address + ## [86.0.0] ### Changed diff --git a/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.test.ts b/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.test.ts index 302a28319f5..7d7c8d9101a 100644 --- a/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.test.ts +++ b/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.test.ts @@ -1,5 +1,5 @@ import { deriveStateFromMetadata } from '@metamask/base-controller'; -import { BtcAccountType } from '@metamask/keyring-api'; +import { BtcAccountType, EthAccountType } from '@metamask/keyring-api'; import { MOCK_ANY_NAMESPACE, Messenger, @@ -23,17 +23,21 @@ import type { TransactionMeta, } from '../../../transaction-controller/src/types'; -const OWNER_ACCOUNTS = [ +const GROUP_ACCOUNTS = [ createMockInternalAccount({ id: 'mock-id-1', address: '0x0000000000000000000000000000000000000001', + type: EthAccountType.Eoa, }), createMockInternalAccount({ - id: 'mock-id-2', - address: '0x0000000000000000000000000000000000000002', + id: 'mock-id-btc-1', + type: BtcAccountType.P2wpkh, }), +]; + +const GROUP_ACCOUNTS_NO_EVM = [ createMockInternalAccount({ - id: 'mock-id-btc', + id: 'mock-id-btc-3', type: BtcAccountType.P2wpkh, }), ]; @@ -59,6 +63,7 @@ type RootMessenger = Messenger< * @param config.mockFetchPositions - The mock fetch positions function * @param config.mockGroupDeFiPositions - The mock group positions function * @param config.mockCalculateDefiMetrics - The mock calculate metrics function + * @param config.mockGroupAccounts - The mock group accounts function * @returns The controller instance, trigger functions, and spies */ function setupController({ @@ -67,21 +72,22 @@ function setupController({ mockFetchPositions = jest.fn(), mockGroupDeFiPositions = jest.fn(), mockCalculateDefiMetrics = jest.fn(), + mockGroupAccounts = GROUP_ACCOUNTS, }: { isEnabled?: () => boolean; mockFetchPositions?: jest.Mock; mockGroupDeFiPositions?: jest.Mock; mockCalculateDefiMetrics?: jest.Mock; mockTrackEvent?: jest.Mock; + mockGroupAccounts?: InternalAccount[]; } = {}) { const messenger: RootMessenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE, }); - const mockListAccounts = jest.fn().mockReturnValue(OWNER_ACCOUNTS); messenger.registerActionHandler( - 'AccountsController:listAccounts', - mockListAccounts, + 'AccountTreeController:getAccountsFromSelectedAccountGroup', + () => mockGroupAccounts, ); const defiPositionControllerMessenger = new Messenger< @@ -95,12 +101,11 @@ function setupController({ }); messenger.delegate({ messenger: defiPositionControllerMessenger, - actions: ['AccountsController:listAccounts'], + actions: ['AccountTreeController:getAccountsFromSelectedAccountGroup'], events: [ - 'KeyringController:unlock', 'KeyringController:lock', 'TransactionController:transactionConfirmed', - 'AccountsController:accountAdded', + 'AccountTreeController:selectedAccountGroupChange', ], }); @@ -132,10 +137,6 @@ function setupController({ const updateSpy = jest.spyOn(controller, 'update' as never); - const triggerUnlock = (): void => { - messenger.publish('KeyringController:unlock'); - }; - const triggerLock = (): void => { messenger.publish('KeyringController:lock'); }; @@ -148,19 +149,19 @@ function setupController({ } as TransactionMeta); }; - const triggerAccountAdded = (account: Partial): void => { + const triggerAccountGroupChange = (): void => { messenger.publish( - 'AccountsController:accountAdded', - account as InternalAccount, + 'AccountTreeController:selectedAccountGroupChange', + 'entropy:test/0', + '', ); }; return { controller, - triggerUnlock, triggerLock, triggerTransactionConfirmed, - triggerAccountAdded, + triggerAccountGroupChange, buildPositionsFetcherSpy, updateSpy, mockFetchPositions, @@ -198,26 +199,8 @@ describe('DeFiPositionsController', () => { expect(stopAllPollingSpy).toHaveBeenCalled(); }); - it('starts polling if the keyring is unlocked', async () => { - const { controller, triggerUnlock } = setupController(); - const startPollingSpy = jest.spyOn(controller, 'startPolling'); - - triggerUnlock(); - - await flushPromises(); - - expect(startPollingSpy).toHaveBeenCalled(); - }); - - it('fetches positions for all accounts when polling', async () => { - const mockFetchPositions = jest.fn().mockImplementation((address) => { - // eslint-disable-next-line jest/no-conditional-in-test - if (OWNER_ACCOUNTS[0].address === address) { - return 'mock-fetch-data-1'; - } - - throw new Error('Error fetching positions'); - }); + it('fetches positions for the selected account when polling', async () => { + const mockFetchPositions = jest.fn().mockResolvedValue('mock-fetch-data-1'); const mockGroupDeFiPositions = jest .fn() .mockReturnValue('mock-grouped-data-1'); @@ -233,17 +216,15 @@ describe('DeFiPositionsController', () => { expect(controller.state).toStrictEqual({ allDeFiPositions: { - [OWNER_ACCOUNTS[0].address]: 'mock-grouped-data-1', - [OWNER_ACCOUNTS[1].address]: null, + [GROUP_ACCOUNTS[0].address]: 'mock-grouped-data-1', }, allDeFiPositionsCount: {}, }); expect(buildPositionsFetcherSpy).toHaveBeenCalled(); - expect(mockFetchPositions).toHaveBeenCalledWith(OWNER_ACCOUNTS[0].address); - expect(mockFetchPositions).toHaveBeenCalledWith(OWNER_ACCOUNTS[1].address); - expect(mockFetchPositions).toHaveBeenCalledTimes(2); + expect(mockFetchPositions).toHaveBeenCalledWith(GROUP_ACCOUNTS[0].address); + expect(mockFetchPositions).toHaveBeenCalledTimes(1); expect(mockGroupDeFiPositions).toHaveBeenCalledWith('mock-fetch-data-1'); expect(mockGroupDeFiPositions).toHaveBeenCalledTimes(1); @@ -293,19 +274,19 @@ describe('DeFiPositionsController', () => { mockGroupDeFiPositions, }); - triggerTransactionConfirmed(OWNER_ACCOUNTS[0].address); + triggerTransactionConfirmed(GROUP_ACCOUNTS[0].address); await flushPromises(); expect(controller.state).toStrictEqual({ allDeFiPositions: { - [OWNER_ACCOUNTS[0].address]: 'mock-grouped-data-1', + [GROUP_ACCOUNTS[0].address]: 'mock-grouped-data-1', }, allDeFiPositionsCount: {}, }); expect(buildPositionsFetcherSpy).toHaveBeenCalled(); - expect(mockFetchPositions).toHaveBeenCalledWith(OWNER_ACCOUNTS[0].address); + expect(mockFetchPositions).toHaveBeenCalledWith(GROUP_ACCOUNTS[0].address); expect(mockFetchPositions).toHaveBeenCalledTimes(1); expect(mockGroupDeFiPositions).toHaveBeenCalledWith('mock-fetch-data-1'); @@ -326,7 +307,33 @@ describe('DeFiPositionsController', () => { isEnabled: () => false, }); - triggerTransactionConfirmed(OWNER_ACCOUNTS[0].address); + triggerTransactionConfirmed(GROUP_ACCOUNTS[0].address); + await flushPromises(); + + expect(controller.state).toStrictEqual( + getDefaultDefiPositionsControllerState(), + ); + + expect(buildPositionsFetcherSpy).toHaveBeenCalled(); + + expect(mockFetchPositions).not.toHaveBeenCalled(); + + expect(mockGroupDeFiPositions).not.toHaveBeenCalled(); + + expect(updateSpy).not.toHaveBeenCalled(); + }); + + it('does not fetch positions for an account when a transaction is confirmed for a different than the selected account', async () => { + const { + controller, + triggerTransactionConfirmed, + buildPositionsFetcherSpy, + updateSpy, + mockFetchPositions, + mockGroupDeFiPositions, + } = setupController(); + + triggerTransactionConfirmed('0x0000000000000000000000000000000000000002'); await flushPromises(); expect(controller.state).toStrictEqual( @@ -342,7 +349,7 @@ describe('DeFiPositionsController', () => { expect(updateSpy).not.toHaveBeenCalled(); }); - it('fetches positions for an account when a new account is added', async () => { + it('fetches positions for the selected evm account when the account group changes', async () => { const mockFetchPositions = jest.fn().mockResolvedValue('mock-fetch-data-1'); const mockGroupDeFiPositions = jest .fn() @@ -350,7 +357,7 @@ describe('DeFiPositionsController', () => { const { controller, - triggerAccountAdded, + triggerAccountGroupChange, buildPositionsFetcherSpy, updateSpy, } = setupController({ @@ -358,23 +365,19 @@ describe('DeFiPositionsController', () => { mockGroupDeFiPositions, }); - const newAccountAddress = '0x0000000000000000000000000000000000000003'; - triggerAccountAdded({ - type: 'eip155:eoa', - address: newAccountAddress, - }); + triggerAccountGroupChange(); await flushPromises(); expect(controller.state).toStrictEqual({ allDeFiPositions: { - [newAccountAddress]: 'mock-grouped-data-1', + [GROUP_ACCOUNTS[0].address]: 'mock-grouped-data-1', }, allDeFiPositionsCount: {}, }); expect(buildPositionsFetcherSpy).toHaveBeenCalled(); - expect(mockFetchPositions).toHaveBeenCalledWith(newAccountAddress); + expect(mockFetchPositions).toHaveBeenCalledWith(GROUP_ACCOUNTS[0].address); expect(mockFetchPositions).toHaveBeenCalledTimes(1); expect(mockGroupDeFiPositions).toHaveBeenCalledWith('mock-fetch-data-1'); @@ -383,22 +386,19 @@ describe('DeFiPositionsController', () => { expect(updateSpy).toHaveBeenCalledTimes(1); }); - it('does not fetch positions for an account when a new account is added and the controller is disabled', async () => { + it('does not fetch positions when the account group changes and there is no evm account', async () => { const { controller, - triggerAccountAdded, + triggerAccountGroupChange, buildPositionsFetcherSpy, updateSpy, mockFetchPositions, mockGroupDeFiPositions, } = setupController({ - isEnabled: () => false, + mockGroupAccounts: GROUP_ACCOUNTS_NO_EVM, }); - triggerAccountAdded({ - type: 'eip155:eoa', - address: '0x0000000000000000000000000000000000000003', - }); + triggerAccountGroupChange(); await flushPromises(); expect(controller.state).toStrictEqual( @@ -430,19 +430,7 @@ describe('DeFiPositionsController', () => { }, }; - const mockMetric2 = { - event: 'mock-event', - category: 'mock-category', - properties: { - totalPositions: 2, - totalMarketValueUSD: 2, - }, - }; - - const mockCalculateDefiMetrics = jest - .fn() - .mockReturnValueOnce(mockMetric1) - .mockReturnValueOnce(mockMetric2); + const mockCalculateDefiMetrics = jest.fn().mockReturnValueOnce(mockMetric1); const { controller } = setupController({ mockGroupDeFiPositions, @@ -454,19 +442,15 @@ describe('DeFiPositionsController', () => { expect(mockCalculateDefiMetrics).toHaveBeenCalled(); expect(mockCalculateDefiMetrics).toHaveBeenCalledWith( - controller.state.allDeFiPositions[OWNER_ACCOUNTS[0].address], + controller.state.allDeFiPositions[GROUP_ACCOUNTS[0].address], ); expect(controller.state.allDeFiPositionsCount).toStrictEqual({ - [OWNER_ACCOUNTS[0].address]: mockMetric1.properties.totalPositions, - [OWNER_ACCOUNTS[1].address]: mockMetric2.properties.totalPositions, + [GROUP_ACCOUNTS[0].address]: mockMetric1.properties.totalPositions, }); - expect(mockTrackEvent).toHaveBeenNthCalledWith(1, mockMetric1); - expect(mockTrackEvent).toHaveBeenNthCalledWith(2, mockMetric2); - expect(mockTrackEvent).toHaveBeenCalledTimes(2); - expect(mockTrackEvent).toHaveBeenNthCalledWith(1, mockMetric1); - expect(mockTrackEvent).toHaveBeenNthCalledWith(2, mockMetric2); + expect(mockTrackEvent).toHaveBeenCalledWith(mockMetric1); + expect(mockTrackEvent).toHaveBeenCalledTimes(1); }); it('only calls track metric when position count changes', async () => { @@ -484,20 +468,10 @@ describe('DeFiPositionsController', () => { }, }; - const mockMetric2 = { - event: 'mock-event', - category: 'mock-category', - properties: { - totalPositions: 2, - totalMarketValueUSD: 2, - }, - }; - const mockCalculateDefiMetrics = jest .fn() .mockReturnValueOnce(mockMetric1) - .mockReturnValueOnce(mockMetric2) - .mockReturnValueOnce(mockMetric2); + .mockReturnValueOnce(mockMetric1); const { controller, triggerTransactionConfirmed } = setupController({ mockGroupDeFiPositions, @@ -505,23 +479,21 @@ describe('DeFiPositionsController', () => { mockTrackEvent, }); - triggerTransactionConfirmed(OWNER_ACCOUNTS[0].address); - triggerTransactionConfirmed(OWNER_ACCOUNTS[0].address); - triggerTransactionConfirmed(OWNER_ACCOUNTS[0].address); + triggerTransactionConfirmed(GROUP_ACCOUNTS[0].address); + triggerTransactionConfirmed(GROUP_ACCOUNTS[0].address); await flushPromises(); expect(mockCalculateDefiMetrics).toHaveBeenCalled(); expect(mockCalculateDefiMetrics).toHaveBeenCalledWith( - controller.state.allDeFiPositions[OWNER_ACCOUNTS[0].address], + controller.state.allDeFiPositions[GROUP_ACCOUNTS[0].address], ); expect(controller.state.allDeFiPositionsCount).toStrictEqual({ - [OWNER_ACCOUNTS[0].address]: mockMetric2.properties.totalPositions, + [GROUP_ACCOUNTS[0].address]: mockMetric1.properties.totalPositions, }); - expect(mockTrackEvent).toHaveBeenCalledTimes(2); - expect(mockTrackEvent).toHaveBeenNthCalledWith(1, mockMetric1); - expect(mockTrackEvent).toHaveBeenNthCalledWith(2, mockMetric2); + expect(mockTrackEvent).toHaveBeenCalledTimes(1); + expect(mockTrackEvent).toHaveBeenCalledWith(mockMetric1); }); describe('metadata', () => { diff --git a/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.ts b/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.ts index 06db940f5ca..b317414ee4f 100644 --- a/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.ts +++ b/packages/assets-controllers/src/DeFiPositionsController/DeFiPositionsController.ts @@ -1,14 +1,15 @@ import type { - AccountsControllerAccountAddedEvent, - AccountsControllerListAccountsAction, -} from '@metamask/accounts-controller'; + AccountTreeControllerGetAccountsFromSelectedAccountGroupAction, + AccountTreeControllerSelectedAccountGroupChangeEvent, +} from '@metamask/account-tree-controller'; import type { ControllerGetStateAction, ControllerStateChangeEvent, StateMetadata, } from '@metamask/base-controller'; -import type { KeyringControllerUnlockEvent } from '@metamask/keyring-controller'; +import { isEvmAccountType } from '@metamask/keyring-api'; import type { KeyringControllerLockEvent } from '@metamask/keyring-controller'; +import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { Messenger } from '@metamask/messenger'; import { StaticIntervalPollingController } from '@metamask/polling-controller'; import type { TransactionControllerTransactionConfirmedEvent } from '@metamask/transaction-controller'; @@ -21,12 +22,9 @@ import { groupDeFiPositions, type GroupedDeFiPositions, } from './group-defi-positions'; -import { reduceInBatchesSerially } from '../assetsUtil'; const TEN_MINUTES_IN_MS = 600_000; -const FETCH_POSITIONS_BATCH_SIZE = 10; - const controllerName = 'DeFiPositionsController'; export type GroupedDeFiPositionsPerChain = { @@ -109,16 +107,16 @@ export type DeFiPositionsControllerStateChangeEvent = /** * The external actions available to the {@link DeFiPositionsController}. */ -export type AllowedActions = AccountsControllerListAccountsAction; +export type AllowedActions = + AccountTreeControllerGetAccountsFromSelectedAccountGroupAction; /** * The external events available to the {@link DeFiPositionsController}. */ export type AllowedEvents = - | KeyringControllerUnlockEvent | KeyringControllerLockEvent | TransactionControllerTransactionConfirmedEvent - | AccountsControllerAccountAddedEvent; + | AccountTreeControllerSelectedAccountGroupChangeEvent; /** * The messenger of the {@link DeFiPositionsController}. @@ -174,10 +172,6 @@ export class DeFiPositionsController extends StaticIntervalPollingController()< this.#fetchPositions = buildPositionFetcher(); this.#isEnabled = isEnabled; - this.messenger.subscribe('KeyringController:unlock', () => { - this.startPolling(null); - }); - this.messenger.subscribe('KeyringController:lock', () => { this.stopAllPolling(); }); @@ -185,22 +179,30 @@ export class DeFiPositionsController extends StaticIntervalPollingController()< this.messenger.subscribe( 'TransactionController:transactionConfirmed', async (transactionMeta) => { - if (!this.#isEnabled()) { + const selectedAddress = this.#getSelectedEvmAdress(); + + if ( + !selectedAddress || + selectedAddress.toLowerCase() !== + transactionMeta.txParams.from.toLowerCase() + ) { return; } - await this.#updateAccountPositions(transactionMeta.txParams.from); + await this.#updateAccountPositions(selectedAddress); }, ); this.messenger.subscribe( - 'AccountsController:accountAdded', - async (account) => { - if (!this.#isEnabled() || !account.type.startsWith('eip155:')) { + 'AccountTreeController:selectedAccountGroupChange', + async () => { + const selectedAddress = this.#getSelectedEvmAdress(); + + if (!selectedAddress) { return; } - await this.#updateAccountPositions(account.address); + await this.#updateAccountPositions(selectedAddress); }, ); @@ -212,57 +214,24 @@ export class DeFiPositionsController extends StaticIntervalPollingController()< return; } - const accounts = this.messenger.call('AccountsController:listAccounts'); - - const initialResult: { - accountAddress: string; - positions: GroupedDeFiPositionsPerChain | null; - }[] = []; - - const results = await reduceInBatchesSerially({ - initialResult, - values: accounts, - batchSize: FETCH_POSITIONS_BATCH_SIZE, - eachBatch: async (workingResult, batch) => { - const batchResults = ( - await Promise.all( - batch.map(async ({ address: accountAddress, type }) => { - if (type.startsWith('eip155:')) { - const positions = - await this.#fetchAccountPositions(accountAddress); - - return { - accountAddress, - positions, - }; - } - - return undefined; - }), - ) - ).filter(Boolean) as { - accountAddress: string; - positions: GroupedDeFiPositionsPerChain | null; - }[]; - - return [...workingResult, ...batchResults]; - }, - }); + const selectedAddress = this.#getSelectedEvmAdress(); - const allDefiPositions = results.reduce( - (acc, { accountAddress, positions }) => { - acc[accountAddress] = positions; - return acc; - }, - {} as DeFiPositionsControllerState['allDeFiPositions'], - ); + if (!selectedAddress) { + return; + } + + const accountPositions = await this.#fetchAccountPositions(selectedAddress); this.update((state) => { - state.allDeFiPositions = allDefiPositions; + state.allDeFiPositions[selectedAddress] = accountPositions; }); } async #updateAccountPositions(accountAddress: string): Promise { + if (!this.#isEnabled()) { + return; + } + const accountPositionsPerChain = await this.#fetchAccountPositions(accountAddress); @@ -314,4 +283,11 @@ export class DeFiPositionsController extends StaticIntervalPollingController()< this.#trackEvent?.(defiMetrics); } } + + #getSelectedEvmAdress(): string | undefined { + return this.messenger + .call('AccountTreeController:getAccountsFromSelectedAccountGroup') + .find((account: InternalAccount) => isEvmAccountType(account.type)) + ?.address; + } } diff --git a/packages/assets-controllers/src/DeFiPositionsController/fetch-positions.ts b/packages/assets-controllers/src/DeFiPositionsController/fetch-positions.ts index cd05d1921c8..427ca587081 100644 --- a/packages/assets-controllers/src/DeFiPositionsController/fetch-positions.ts +++ b/packages/assets-controllers/src/DeFiPositionsController/fetch-positions.ts @@ -1,3 +1,5 @@ +import { timeoutWithRetry } from '../utils/timeout-with-retry'; + export type DefiPositionResponse = AdapterResponse<{ tokens: ProtocolToken[]; }>; @@ -58,6 +60,9 @@ export type Balance = { // TODO: Update with prod API URL when available export const DEFI_POSITIONS_API_URL = 'https://defiadapters.api.cx.metamask.io'; +const EIGHT_SECONDS_IN_MS = 8_000; +const MAX_RETRIES = 1; + /** * Builds a function that fetches DeFi positions for a given account address * @@ -65,8 +70,10 @@ export const DEFI_POSITIONS_API_URL = 'https://defiadapters.api.cx.metamask.io'; */ export function buildPositionFetcher() { return async (accountAddress: string): Promise => { - const defiPositionsResponse = await fetch( - `${DEFI_POSITIONS_API_URL}/positions/${accountAddress}`, + const defiPositionsResponse = await timeoutWithRetry( + () => fetch(`${DEFI_POSITIONS_API_URL}/positions/${accountAddress}`), + EIGHT_SECONDS_IN_MS, + MAX_RETRIES, ); if (defiPositionsResponse.status !== 200) { diff --git a/packages/assets-controllers/src/utils/timeout-with-retry.test.ts b/packages/assets-controllers/src/utils/timeout-with-retry.test.ts new file mode 100644 index 00000000000..6b4ec9caa9b --- /dev/null +++ b/packages/assets-controllers/src/utils/timeout-with-retry.test.ts @@ -0,0 +1,109 @@ +import { timeoutWithRetry } from './timeout-with-retry'; +import { flushPromises } from '../../../../tests/helpers'; + +describe('timeoutWithRetry', () => { + const timeout = 1000; + + beforeEach(() => { + jest.useFakeTimers(); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('returns the result when call completes before timeout', async () => { + const mockCall = jest.fn(async () => 'success'); + + const resultPromise = timeoutWithRetry(mockCall, timeout, 0); + jest.runAllTimers(); + const result = await resultPromise; + + expect(result).toBe('success'); + expect(mockCall).toHaveBeenCalledTimes(1); + }); + + describe('retry behaviour', () => { + it('throws when maxRetries is negative', async () => { + const mockCall = jest.fn(async () => 'success'); + + await expect(() => + timeoutWithRetry(mockCall, timeout, -1), + ).rejects.toThrow('maxRetries must be greater than or equal to 0'); + }); + + it('returns the result when call completes just before timeout', async () => { + const mockCall = createMockCallWithRetries(timeout, 0); + + const resultPromise = timeoutWithRetry(mockCall, timeout, 0); + jest.runAllTimers(); + const result = await resultPromise; + + expect(result).toBe('success'); + expect(mockCall).toHaveBeenCalledTimes(1); + }); + + it('succeeds after multiple retries', async () => { + const mockCall = createMockCallWithRetries(timeout, 2); + + const resultPromise = timeoutWithRetry(mockCall, timeout, 3); + jest.runAllTimers(); + await flushPromises(); + jest.runAllTimers(); + const result = await resultPromise; + + expect(result).toBe('success'); + expect(mockCall).toHaveBeenCalledTimes(3); + }); + + it('throws when all retries are exhausted', async () => { + const mockCall = createMockCallWithRetries(timeout, 2); + + const resultPromise = timeoutWithRetry(mockCall, timeout, 1); + jest.runAllTimers(); + await flushPromises(); + jest.runAllTimers(); + + await expect(resultPromise).rejects.toThrow('timeout'); + expect(mockCall).toHaveBeenCalledTimes(2); + }); + }); + + describe('non-timeout errors', () => { + it('throws immediately on non-timeout error without retrying', async () => { + const customError = new Error('custom error'); + const mockCall = jest.fn(async () => { + throw customError; + }); + + const resultPromise = timeoutWithRetry(mockCall, timeout, 0); + jest.runAllTimers(); + + await expect(resultPromise).rejects.toThrow('custom error'); + expect(mockCall).toHaveBeenCalledTimes(1); + }); + }); +}); + +/** + * @param timeout - The timeout in milliseconds. + * @param timeoutsBeforeSuccess - The number of timeouts before the call succeeds. + * @returns A mock call function that times out for a specific number of times before returning 'success'. + */ +function createMockCallWithRetries( + timeout: number, + timeoutsBeforeSuccess: number, +) { + let callCount = 0; + const mockCall = jest.fn(async () => { + callCount += 1; + + if (callCount < timeoutsBeforeSuccess + 1) { + await new Promise((resolve) => setTimeout(resolve, timeout + 1)); + } + + return 'success'; + }); + + return mockCall; +} diff --git a/packages/assets-controllers/src/utils/timeout-with-retry.ts b/packages/assets-controllers/src/utils/timeout-with-retry.ts new file mode 100644 index 00000000000..90e0be1b964 --- /dev/null +++ b/packages/assets-controllers/src/utils/timeout-with-retry.ts @@ -0,0 +1,39 @@ +import { assert } from '@metamask/utils'; + +const TIMEOUT_ERROR = new Error('timeout'); + +/** + * + * @param call - The async function to call. + * @param timeout - Timeout in milliseconds for each call attempt. + * @param maxRetries - Maximum number of retries on timeout. + * @returns The resolved value of the call, or throws the last error if not a timeout or retries exhausted. + */ +// eslint-disable-next-line consistent-return +export async function timeoutWithRetry Promise>( + call: T, + timeout: number, + maxRetries: number, + // @ts-expect-error TS2366: Assertion guarantees loop executes +): Promise>> { + assert(maxRetries >= 0, 'maxRetries must be greater than or equal to 0'); + + let attempt = 0; + + while (attempt <= maxRetries) { + try { + return (await Promise.race([ + call(), + new Promise((_resolve, reject) => + setTimeout(() => reject(TIMEOUT_ERROR), timeout), + ), + ])) as Awaited>; + } catch (err) { + if (err === TIMEOUT_ERROR && attempt < maxRetries) { + attempt += 1; + continue; + } + throw err; + } + } +}