diff --git a/src/bunit.generators.internal/Web.AngleSharp/IElementWrapper.cs b/src/bunit.generators.internal/Web.AngleSharp/IElementWrapper.cs index de2596bdc..92171031d 100644 --- a/src/bunit.generators.internal/Web.AngleSharp/IElementWrapper.cs +++ b/src/bunit.generators.internal/Web.AngleSharp/IElementWrapper.cs @@ -14,5 +14,10 @@ public interface IElementWrapper where TElement : class, IElement /// Gets the wrapped element. /// TElement WrappedElement { get; } + + /// + /// Gets the element wrapper factory used by this wrapper. + /// + IElementWrapperFactory Factory { get; } } #nullable restore diff --git a/src/bunit.generators.internal/Web.AngleSharp/WrapperBase.cs b/src/bunit.generators.internal/Web.AngleSharp/WrapperBase.cs index 8dd4d5963..aaba6d18a 100644 --- a/src/bunit.generators.internal/Web.AngleSharp/WrapperBase.cs +++ b/src/bunit.generators.internal/Web.AngleSharp/WrapperBase.cs @@ -33,6 +33,16 @@ public TElement WrappedElement } } + /// + /// Gets the element wrapper factory used by this wrapper. + /// + [DebuggerNonUserCode] + public IElementWrapperFactory Factory + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => elementFactory; + } + /// /// Creates an instance of the class. /// diff --git a/src/bunit.web.query/ByLabelTextElementFactory.cs b/src/bunit.web.query/ByLabelTextElementFactory.cs index 43a96f826..a8763755e 100644 --- a/src/bunit.web.query/ByLabelTextElementFactory.cs +++ b/src/bunit.web.query/ByLabelTextElementFactory.cs @@ -3,7 +3,7 @@ namespace Bunit; -internal sealed class ByLabelTextElementFactory : IElementWrapperFactory +internal sealed class ByLabelTextElementFactory : IElementWrapperFactory, IComponentAccessor { private readonly IRenderedComponent testTarget; private readonly string labelText; @@ -11,6 +11,8 @@ internal sealed class ByLabelTextElementFactory : IElementWrapperFactory public Action? OnElementReplaced { get; set; } + public IRenderedComponent Component => testTarget; + public ByLabelTextElementFactory(IRenderedComponent testTarget, string labelText, ByLabelTextOptions options) { this.testTarget = testTarget; diff --git a/src/bunit/Extensions/ElementExtensions.cs b/src/bunit/Extensions/ElementExtensions.cs new file mode 100644 index 000000000..aefaf1e22 --- /dev/null +++ b/src/bunit/Extensions/ElementExtensions.cs @@ -0,0 +1,44 @@ +using AngleSharp.Dom; +using Bunit.Rendering; + +namespace Bunit; + +/// +/// Provides extension methods for to find components rendered inside them. +/// +public static class ElementExtensions +{ + /// + /// Retrieves the first component of type that is rendered inside the specified . + /// + public static IRenderedComponent FindComponent(this IElement element) + where TComponent : IComponent + { + ArgumentNullException.ThrowIfNull(element); + + var componentAccessor = element.GetComponentAccessor(); + if (componentAccessor is null) + { + throw new InvalidOperationException( + $"Unable to find component of type {typeof(TComponent).Name} for the given element."); + } + + var component = componentAccessor.Component; + if (component.Instance is TComponent) + { + return (IRenderedComponent)component; + } + + var renderer = GetRendererFromComponent(component); + var foundComponent = renderer.FindComponentForElement(element); + if (foundComponent is not null) + { + return foundComponent; + } + + throw new InvalidOperationException($"Unable to find component of type {typeof(TComponent).Name} for the given element."); + } + + private static BunitRenderer GetRendererFromComponent(IRenderedComponent component) + => component.Services.GetRequiredService().Renderer; +} diff --git a/src/bunit/Extensions/Internal/CssSelectorElementFactory.cs b/src/bunit/Extensions/Internal/CssSelectorElementFactory.cs index 4c0f8d5e9..7ae58ba30 100644 --- a/src/bunit/Extensions/Internal/CssSelectorElementFactory.cs +++ b/src/bunit/Extensions/Internal/CssSelectorElementFactory.cs @@ -3,13 +3,15 @@ namespace Bunit; -internal sealed class CssSelectorElementFactory : IElementWrapperFactory +internal sealed class CssSelectorElementFactory : IElementWrapperFactory, IComponentAccessor { private readonly IRenderedComponent testTarget; private readonly string cssSelector; public Action? OnElementReplaced { get; set; } + public IRenderedComponent Component => testTarget; + public CssSelectorElementFactory(IRenderedComponent testTarget, string cssSelector) { this.testTarget = testTarget; diff --git a/src/bunit/Extensions/Internal/ElementFactoryAccessor.cs b/src/bunit/Extensions/Internal/ElementFactoryAccessor.cs new file mode 100644 index 000000000..f1794669a --- /dev/null +++ b/src/bunit/Extensions/Internal/ElementFactoryAccessor.cs @@ -0,0 +1,26 @@ +using AngleSharp.Dom; +using Bunit.Web.AngleSharp; + +namespace Bunit; + +/// +/// Internal extensions for working with wrapped elements. +/// +internal static class ElementFactoryAccessor +{ + /// + /// Attempts to get the component accessor from an element factory. + /// + /// The element to get the component from. + /// The component accessor if available, null otherwise. + internal static IComponentAccessor? GetComponentAccessor(this IElement element) + { + var factory = element.GetElementFactory(); + return factory as IComponentAccessor; + } + + private static IElementWrapperFactory? GetElementFactory(this IElement element) + { + return element is IElementWrapper wrapper ? wrapper.Factory : null; + } +} diff --git a/src/bunit/Extensions/Internal/IComponentAccessor.cs b/src/bunit/Extensions/Internal/IComponentAccessor.cs new file mode 100644 index 000000000..bc21c79f4 --- /dev/null +++ b/src/bunit/Extensions/Internal/IComponentAccessor.cs @@ -0,0 +1,12 @@ +namespace Bunit; + +/// +/// Interface for accessing the component that owns an element. +/// +internal interface IComponentAccessor +{ + /// + /// Gets the component that owns this element. + /// + IRenderedComponent Component { get; } +} diff --git a/src/bunit/InternalsVisibleTo.cs b/src/bunit/InternalsVisibleTo.cs index a3c9cf4b0..6bdae613c 100644 --- a/src/bunit/InternalsVisibleTo.cs +++ b/src/bunit/InternalsVisibleTo.cs @@ -1 +1,2 @@ [assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Bunit.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010001be6b1a2ca57b09b7040e2ab0993e515296ae22aef4031a4fe388a1336fe21f69c7e8610e9935de6ed18d94b5c98429f99ef62ce3d0af28a7088f856239368ea808ad4c448aa2a8075ed581f989f36ed0d0b8b1cfcaf1ff6a4506c8a99b7024b6eb56996d08e3c9c1cf5db59bff96fcc63ccad155ef7fc63aab6a69862437b6")] +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Bunit.Web.Query, PublicKey=002400000480000094000000060200000024000052534131000400000100010001be6b1a2ca57b09b7040e2ab0993e515296ae22aef4031a4fe388a1336fe21f69c7e8610e9935de6ed18d94b5c98429f99ef62ce3d0af28a7088f856239368ea808ad4c448aa2a8075ed581f989f36ed0d0b8b1cfcaf1ff6a4506c8a99b7024b6eb56996d08e3c9c1cf5db59bff96fcc63ccad155ef7fc63aab6a69862437b6")] diff --git a/src/bunit/Rendering/BunitRenderer.cs b/src/bunit/Rendering/BunitRenderer.cs index 646c008a9..5104bd532 100644 --- a/src/bunit/Rendering/BunitRenderer.cs +++ b/src/bunit/Rendering/BunitRenderer.cs @@ -3,6 +3,7 @@ using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; +using AngleSharp.Dom; using Microsoft.Extensions.Logging; namespace Bunit.Rendering; @@ -26,6 +27,7 @@ public sealed class BunitRenderer : Renderer private readonly HashSet returnedRenderedComponentIds = new(); private readonly List rootComponents = new(); + private readonly Dictionary elementReferenceToComponentId = new(); private readonly ILogger logger; private bool disposed; private TaskCompletionSource unhandledExceptionTsc = new(TaskCreationOptions.RunContinuationsAsynchronously); @@ -453,6 +455,7 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch) var id = renderBatch.DisposedComponentIDs.Array[i]; disposedComponentIds.Add(id); returnedRenderedComponentIds.Remove(id); + RemoveElementReferencesForComponent(id); } for (var i = 0; i < renderBatch.UpdatedComponents.Count; i++) @@ -467,6 +470,8 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch) var componentState = GetComponentState(diff.ComponentId); var renderedComponent = (IRenderedComponent)componentState; + TrackElementReferencesForComponent(diff.ComponentId); + if (returnedRenderedComponentIds.Contains(diff.ComponentId)) { renderedComponent.UpdateState(hasRendered: true, isMarkupGenerationRequired: diff.Edits.Count > 0); @@ -519,6 +524,101 @@ static bool IsParentComponentAlreadyUpdated(int componentId, in RenderBatch rend } } + private void TrackElementReferencesForComponent(int componentId) + { + var frames = GetCurrentRenderTreeFrames(componentId); + TrackElementReferencesInFrames(frames, componentId); + } + + private void TrackElementReferencesInFrames(ArrayRange frames, int componentId) + { + for (var i = 0; i < frames.Count; i++) + { + ref var frame = ref frames.Array[i]; + + if (frame.FrameType == RenderTreeFrameType.ElementReferenceCapture) + { + var elementReferenceId = frame.ElementReferenceCaptureId; + if (elementReferenceId != null) + { + elementReferenceToComponentId[elementReferenceId] = componentId; + } + } + else if (frame.FrameType == RenderTreeFrameType.Component) + { + TrackElementReferencesForComponent(frame.ComponentId); + } + } + } + + private void RemoveElementReferencesForComponent(int componentId) + { + var keysToRemove = elementReferenceToComponentId + .Where(kvp => kvp.Value == componentId) + .Select(kvp => kvp.Key) + .ToList(); + + foreach (var key in keysToRemove) + { + elementReferenceToComponentId.Remove(key); + } + } + + internal IRenderedComponent? FindComponentForElement(IElement element) + where TComponent : IComponent + { + var elementReferenceId = element.GetAttribute("blazor:elementReference"); + if (elementReferenceId is not null && elementReferenceToComponentId.TryGetValue(elementReferenceId, out var componentId)) + { + return GetRenderedComponent(componentId); + } + + return FindComponentByElementContainment(element); + } + + private IRenderedComponent? FindComponentByElementContainment(IElement element) + where TComponent : IComponent + { + List renderedComponentIdsWhenStarted = [..returnedRenderedComponentIds]; + var components = new List>(returnedRenderedComponentIds.Count); + + foreach (var parentComponent in renderedComponentIdsWhenStarted.Select(GetRenderedComponent)) + { + components.AddRange(FindComponents(parentComponent)); + } + + return components.FirstOrDefault(component => ComponentContainsElement(component, element)); + } + + private static bool ComponentContainsElement(IRenderedComponent component, IElement element) + where TComponent : IComponent + { + foreach (var node in component.Nodes) + { + if (node is IElement nodeElement && nodeElement.Equals(element)) + { + return true; + } + if (IsDescendantOf(element, node)) + { + return true; + } + } + return false; + } + + private static bool IsDescendantOf(IElement element, INode potentialAncestor) + { + var current = element.Parent; + while (current is not null) + { + if (current == potentialAncestor) + return true; + current = current.Parent; + } + return false; + } + /// internal new ArrayRange GetCurrentRenderTreeFrames(int componentId) => base.GetCurrentRenderTreeFrames(componentId); diff --git a/tests/bunit.tests/Extensions/FindComponentTest.razor b/tests/bunit.tests/Extensions/FindComponentTest.razor new file mode 100644 index 000000000..3a7fdc807 --- /dev/null +++ b/tests/bunit.tests/Extensions/FindComponentTest.razor @@ -0,0 +1,166 @@ +@inherits BunitContext + +@code { + [Fact] + public void FindComponentByElementHoldsInstance() + { + var cut = Render(); + var button = cut.Find("button"); + + var component = button.FindComponent(); + + component.ShouldNotBeNull(); + } + + [Fact] + public void MarkupShouldNotBeModifiedWhenFindingAComponent() + { + var cut = Render(); + var button = cut.Find("button"); + + var component = button.FindComponent(); + + component.MarkupMatches(@); + } + + [Fact] + public void FindComponentByElementReturnsCorrectInstance() + { + var cut = Render(); + var button = cut.Find("button"); + button.Click(); + + var component = button.FindComponent(); + + component.ShouldNotBeNull(); + component.Instance.Count.ShouldBe(1); + } + + [Fact] + public void FindingInstanceBeforeActionStillReflectsState() + { + var cut = Render(); + var button = cut.Find("button"); + var component = button.FindComponent(); + + button.Click(); + + component.Instance.Count.ShouldBe(1); + } + + [Fact] + public void RetrievingNestedComponentFromParentComponentIsSameInstance() + { + var cut = Render(); + var button = cut.Find("button"); + button.Click(); + + var component = button.FindComponent(); + var componentFromInner = cut.FindComponent().Find("button").FindComponent(); + + component.Instance.ShouldBeSameAs(componentFromInner.Instance); + component.Instance.Count.ShouldBe(1); + } + + [Fact] + public void FindComponentWorksWithElementReferences() + { + var cut = Render(); + var button = cut.Find("button"); + + button.HasAttribute("blazor:elementReference").ShouldBeTrue(); + + var component = button.FindComponent(); + + component.ShouldNotBeNull(); + component.Instance.ShouldNotBeNull(); + component.Instance.GetButtonText().ShouldBe("Click me!"); + } + + [Fact] + public void CantFindComponentThatChildOfElement() + { + var cut = Render(); + var button = cut.Find("button"); + + Action act = () => button.FindComponent(); + + act.ShouldThrow(); + } + + private sealed class ButtonComponent : ComponentBase + { + protected override void BuildRenderTree(RenderTreeBuilder builder) + { + builder.OpenElement(0, "button"); + builder.AddContent(1, "button content"); + builder.CloseElement(); + } + } + + private sealed class CounterComponent : ComponentBase + { + public int Count { get; private set; } + + protected override void BuildRenderTree(RenderTreeBuilder builder) + { + builder.OpenElement(0, "button"); + builder.AddAttribute(1, "onclick", EventCallback.Factory.Create(this, Increment)); + builder.AddContent(2, $"Count: {Count}"); + builder.CloseElement(); + } + + public void Increment() => Count++; + } + + private sealed class ParentComponentThatHasCounter : ComponentBase + { + protected override void BuildRenderTree(RenderTreeBuilder builder) + { + builder.OpenComponent(0); + builder.CloseComponent(); + } + } + + private sealed class ComponentWithElementRef : ComponentBase + { + private ElementReference buttonRef; + private bool isInitialized = false; + + protected override void BuildRenderTree(RenderTreeBuilder builder) + { + builder.OpenElement(0, "div"); + builder.AddContent(1, "Component with element ref:"); + + builder.OpenElement(2, "button"); + builder.AddElementReferenceCapture(3, value => buttonRef = value); + builder.AddContent(4, "Click me!"); + builder.CloseElement(); + + builder.CloseElement(); + } + + protected override void OnAfterRender(bool firstRender) + { + if (firstRender) + { + isInitialized = true; + } + } + + public string GetButtonText() => isInitialized ? "Click me!" : "Not ready"; + } + + private sealed class ParentComponentWithButtonAndCounter : ComponentBase + { + protected override void BuildRenderTree(RenderTreeBuilder builder) + { + builder.OpenElement(1, "button"); + builder.AddContent(2, "Increment Counter"); + builder.CloseElement(); + + builder.OpenComponent(3); + builder.CloseComponent(); + } + } +}