diff --git a/src/utils/mergeRefs/ko/mergeRefs.md b/src/utils/mergeRefs/ko/mergeRefs.md index 0cdefcbf..6aea8447 100644 --- a/src/utils/mergeRefs/ko/mergeRefs.md +++ b/src/utils/mergeRefs/ko/mergeRefs.md @@ -6,7 +6,7 @@ ```ts function mergeRefs( - ...refs: Array | RefCallback | null | undefined> + ...refs: Array | undefined> ): RefCallback; ``` @@ -15,7 +15,7 @@ function mergeRefs( diff --git a/src/utils/mergeRefs/mergeRefs.md b/src/utils/mergeRefs/mergeRefs.md index 35242c45..74f964a5 100644 --- a/src/utils/mergeRefs/mergeRefs.md +++ b/src/utils/mergeRefs/mergeRefs.md @@ -6,7 +6,7 @@ This function takes multiple refs (RefObject or RefCallback) and returns a singl ```ts function mergeRefs( - ...refs: Array | RefCallback | null | undefined> + ...refs: Array | undefined> ): RefCallback; ``` @@ -15,7 +15,7 @@ function mergeRefs( diff --git a/src/utils/mergeRefs/mergeRefs.spec.ts b/src/utils/mergeRefs/mergeRefs.spec.ts index f27cf243..e74e5637 100644 --- a/src/utils/mergeRefs/mergeRefs.spec.ts +++ b/src/utils/mergeRefs/mergeRefs.spec.ts @@ -1,6 +1,6 @@ import { useCallback, useRef } from 'react'; import { act } from '@testing-library/react'; -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; import { renderHookSSR } from '../../_internal/test-utils/renderHookSSR.tsx'; @@ -75,4 +75,54 @@ describe('mergeRefs', () => { expect(result.current.ref1.current).toBe(value); expect(ref3Value).toBe(value); }); + + it('should call cleanup functions returned by callback refs', () => { + const cleanupCalls: string[] = []; + + const callbackRef1 = vi.fn(() => { + return () => { + cleanupCalls.push('cleanup1'); + }; + }); + + const callbackRef2 = vi.fn(() => { + return () => { + cleanupCalls.push('cleanup2'); + }; + }); + + const mergedRef = mergeRefs(callbackRef1, callbackRef2); + const value = 'test-value'; + + act(() => { + mergedRef(value); + }); + + const cleanupFn = mergedRef(null); + if (cleanupFn) { + cleanupFn(); + } + + expect(cleanupCalls).toEqual(['cleanup1', 'cleanup2']); + expect(callbackRef1).toHaveBeenCalledWith(value); + expect(callbackRef2).toHaveBeenCalledWith(value); + }); + + it('verifies that object refs initialize correctly without cleanup functions', () => { + const refObj = { current: 'initial' }; + const mergedRef = mergeRefs(refObj); + + act(() => { + mergedRef('new-value'); + }); + expect(refObj.current).toBe('new-value'); + + const cleanupFn = mergedRef(null); + expect(cleanupFn).toBeInstanceOf(Function); + + if (cleanupFn) { + cleanupFn(); + } + expect(refObj.current).toBeNull(); + }); }); diff --git a/src/utils/mergeRefs/mergeRefs.ts b/src/utils/mergeRefs/mergeRefs.ts index b643558b..79bf65e5 100644 --- a/src/utils/mergeRefs/mergeRefs.ts +++ b/src/utils/mergeRefs/mergeRefs.ts @@ -1,4 +1,7 @@ -import { RefCallback, RefObject } from 'react'; +import { Ref, RefCallback } from 'react'; + +type StrictRef = NonNullable>; +type RefCleanup = ReturnType>; /** * @description @@ -7,7 +10,7 @@ import { RefCallback, RefObject } from 'react'; * * @template T - The type of target to be referenced. * - * @param {Array | RefCallback | null | undefined>} refs - An array of refs to be merged. Each ref can be either a RefObject or RefCallback. + * @param {Array | undefined>} refs - An array of refs to be merged. Each ref can be either a RefObject or RefCallback. * * @returns {RefCallback} A single ref callback that updates all provided refs. * @@ -34,19 +37,39 @@ import { RefCallback, RefObject } from 'react'; * return
; * } */ -export function mergeRefs(...refs: Array | RefCallback | null | undefined>): RefCallback { + +function assignRef(ref: StrictRef, value: T | null): RefCleanup { + if (typeof ref === 'function') { + return ref(value); + } + + ref.current = value; +} + +export function mergeRefs(...refs: Array | undefined>): RefCallback { + const availableRefs = refs.filter(ref => ref != null); + const cleanupMap = new Map, Exclude, void>>(); + return value => { - for (const ref of refs) { - if (ref == null) { - continue; + for (const ref of availableRefs) { + const cleanup = assignRef(ref, value); + if (cleanup) { + cleanupMap.set(ref, cleanup); } + } - if (typeof ref === 'function') { - ref(value); - continue; + return () => { + for (const ref of availableRefs) { + const cleanup = cleanupMap.get(ref); + if (cleanup && typeof cleanup === 'function') { + cleanup(); + continue; + } + + assignRef(ref, null); } - (ref as RefObject).current = value; - } + cleanupMap.clear(); + }; }; }