Skip to content
Open
116 changes: 116 additions & 0 deletions src/Build.UnitTests/BackEnd/TaskHostNodeKey_Tests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.Build.Internal;
using Shouldly;
using Xunit;

namespace Microsoft.Build.Engine.UnitTests.BackEnd
{
/// <summary>
/// Tests for TaskHostNodeKey record struct functionality.
/// </summary>
public class TaskHostNodeKey_Tests
{
[Fact]
public void TaskHostNodeKey_Equality_SameValues_AreEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);

key1.ShouldBe(key2);
(key1 == key2).ShouldBeTrue();
key1.GetHashCode().ShouldBe(key2.GetHashCode());
}

[Fact]
public void TaskHostNodeKey_Equality_DifferentNodeId_AreNotEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 2);

key1.ShouldNotBe(key2);
(key1 != key2).ShouldBeTrue();
}

[Fact]
public void TaskHostNodeKey_Equality_DifferentHandshakeOptions_AreNotEqual()
{
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.X64, 1);

key1.ShouldNotBe(key2);
(key1 != key2).ShouldBeTrue();
}

[Fact]
public void TaskHostNodeKey_CanBeUsedAsDictionaryKey()
{
var dict = new System.Collections.Generic.Dictionary<TaskHostNodeKey, string>();
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.X64, 2);

dict[key1] = "value1";
dict[key2] = "value2";

dict[key1].ShouldBe("value1");
dict[key2].ShouldBe("value2");

// Create a new key with same values as key1
var key1Copy = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1);
dict[key1Copy].ShouldBe("value1");
}

[Fact]
public void TaskHostNodeKey_LargeNodeId_Works()
{
// Test that we can use node IDs greater than 255 (the previous limit)
var key1 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 256);
var key2 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, 1000);
var key3 = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, int.MaxValue);

key1.NodeId.ShouldBe(256);
key2.NodeId.ShouldBe(1000);
key3.NodeId.ShouldBe(int.MaxValue);

// Ensure they are all different
key1.ShouldNotBe(key2);
key2.ShouldNotBe(key3);
key1.ShouldNotBe(key3);
}

[Fact]
public void TaskHostNodeKey_NegativeNodeId_Works()
{
// Traditional multi-proc builds use -1 for node ID
var key = new TaskHostNodeKey(HandshakeOptions.TaskHost | HandshakeOptions.NET, -1);

key.NodeId.ShouldBe(-1);
key.HandshakeOptions.ShouldBe(HandshakeOptions.TaskHost | HandshakeOptions.NET);
}

[Fact]
public void TaskHostNodeKey_AllHandshakeOptions_Work()
{
// Test various HandshakeOptions combinations
HandshakeOptions[] optionsList =
[
HandshakeOptions.None,
HandshakeOptions.TaskHost,
HandshakeOptions.TaskHost | HandshakeOptions.NET,
HandshakeOptions.TaskHost | HandshakeOptions.X64,
HandshakeOptions.TaskHost | HandshakeOptions.NET | HandshakeOptions.NodeReuse,
HandshakeOptions.TaskHost | HandshakeOptions.CLR2,
HandshakeOptions.TaskHost | HandshakeOptions.Arm64
];

foreach (var options in optionsList)
{
var key = new TaskHostNodeKey(options, 42);

key.HandshakeOptions.ShouldBe(options);
key.NodeId.ShouldBe(42);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,43 @@ internal class NodeProviderOutOfProcTaskHost : NodeProviderOutOfProcBase, INodeP

/// <summary>
/// A mapping of all the task host nodes managed by this provider.
/// The key is a TaskHostNodeKey combining HandshakeOptions and scheduled node ID.
/// </summary>
private ConcurrentDictionary<int, NodeContext> _nodeContexts;
private ConcurrentDictionary<TaskHostNodeKey, NodeContext> _nodeContexts;

/// <summary>
/// Reverse mapping from communication node ID to TaskHostNodeKey.
/// Used for O(1) lookup when handling node termination from ShutdownAllNodes.
/// </summary>
private ConcurrentDictionary<int, TaskHostNodeKey> _nodeIdToNodeKey;

/// <summary>
/// A mapping of all of the INodePacketFactories wrapped by this provider.
/// Keyed by the communication node ID (NodeContext.NodeId) for O(1) packet routing.
/// Thread-safe to support parallel taskhost creation in /mt mode where multiple thread nodes
/// can simultaneously create their own taskhosts.
/// </summary>
private ConcurrentDictionary<int, INodePacketFactory> _nodeIdToPacketFactory;

/// <summary>
/// A mapping of all of the INodePacketHandlers wrapped by this provider.
/// Keyed by the communication node ID (NodeContext.NodeId) for O(1) packet routing.
/// Thread-safe to support parallel taskhost creation in /mt mode where multiple thread nodes
/// can simultaneously create their own taskhosts.
/// </summary>
private ConcurrentDictionary<int, INodePacketHandler> _nodeIdToPacketHandler;

/// <summary>
/// Keeps track of the set of nodes for which we have not yet received shutdown notification.
/// Keeps track of the set of node IDs for which we have not yet received shutdown notification.
/// </summary>
private HashSet<int> _activeNodes;

/// <summary>
/// Counter for generating unique communication node IDs.
/// Incremented atomically for each new node created.
/// </summary>
private int _nextNodeId;

/// <summary>
/// Packet factory we use if there's not already one associated with a particular context.
/// </summary>
Expand Down Expand Up @@ -169,12 +184,23 @@ public IList<NodeInfo> CreateNodes(int nextNodeId, INodePacketFactory packetFact

/// <summary>
/// Sends data to the specified node.
/// Note: For task hosts, use the overload that takes TaskHostNodeKey instead.
/// </summary>
/// <param name="nodeId">The node to which data shall be sent.</param>
/// <param name="packet">The packet to send.</param>
public void SendData(int nodeId, INodePacket packet)
{
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeId, out NodeContext context), "Invalid host context specified: {0}.", nodeId);
throw new NotImplementedException("For task hosts, use the overload that takes TaskHostNodeKey.");
}

/// <summary>
/// Sends data to the specified task host node.
/// </summary>
/// <param name="nodeKey">The task host node key identifying the target node.</param>
/// <param name="packet">The packet to send.</param>
internal void SendData(TaskHostNodeKey nodeKey, INodePacket packet)
{
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeKey, out NodeContext context), "Invalid host context specified: {0}.", nodeKey);

SendData(context, packet);
}
Expand Down Expand Up @@ -211,10 +237,12 @@ public void ShutdownAllNodes()
public void InitializeComponent(IBuildComponentHost host)
{
this.ComponentHost = host;
_nodeContexts = new ConcurrentDictionary<int, NodeContext>();
_nodeContexts = new ConcurrentDictionary<TaskHostNodeKey, NodeContext>();
_nodeIdToNodeKey = new ConcurrentDictionary<int, TaskHostNodeKey>();
_nodeIdToPacketFactory = new ConcurrentDictionary<int, INodePacketFactory>();
_nodeIdToPacketHandler = new ConcurrentDictionary<int, INodePacketHandler>();
_activeNodes = new HashSet<int>();
_activeNodes = [];
_nextNodeId = 0;

_noNodesActiveEvent = new ManualResetEvent(true);
_localPacketFactory = new NodePacketFactory();
Expand Down Expand Up @@ -569,17 +597,16 @@ private static string GetPathFromEnvironmentOrDefault(string environmentVariable
/// Make sure a node in the requested context exists.
/// </summary>
internal bool AcquireAndSetUpHost(
HandshakeOptions hostContext,
int taskHostNodeId,
TaskHostNodeKey nodeKey,
INodePacketFactory factory,
INodePacketHandler handler,
TaskHostConfiguration configuration,
in TaskHostParameters taskHostParameters)
{
bool nodeCreationSucceeded;
if (!_nodeContexts.ContainsKey(taskHostNodeId))
if (!_nodeContexts.ContainsKey(nodeKey))
{
nodeCreationSucceeded = CreateNode(hostContext, taskHostNodeId, factory, handler, configuration, taskHostParameters);
nodeCreationSucceeded = CreateNode(nodeKey, factory, handler, configuration, taskHostParameters);
}
else
{
Expand All @@ -589,9 +616,10 @@ internal bool AcquireAndSetUpHost(

if (nodeCreationSucceeded)
{
NodeContext context = _nodeContexts[taskHostNodeId];
_nodeIdToPacketFactory[taskHostNodeId] = factory;
_nodeIdToPacketHandler[taskHostNodeId] = handler;
NodeContext context = _nodeContexts[nodeKey];
// Map the transport ID directly to the handlers for O(1) packet routing
_nodeIdToPacketFactory[context.NodeId] = factory;
_nodeIdToPacketHandler[context.NodeId] = handler;

// Configure the node.
context.SendData(configuration);
Expand All @@ -604,25 +632,35 @@ internal bool AcquireAndSetUpHost(
/// <summary>
/// Expected to be called when TaskHostTask is done with host of the given context.
/// </summary>
internal void DisconnectFromHost(int nodeId)
internal void DisconnectFromHost(TaskHostNodeKey nodeKey)
{
bool successRemoveFactory = _nodeIdToPacketFactory.TryRemove(nodeId, out _);
bool successRemoveHandler = _nodeIdToPacketHandler.TryRemove(nodeId, out _);
ErrorUtilities.VerifyThrow(_nodeContexts.TryGetValue(nodeKey, out NodeContext context), "Node context not found for key: {0}. Was the node created?", nodeKey);

bool successRemoveFactory = _nodeIdToPacketFactory.TryRemove(context.NodeId, out _);
bool successRemoveHandler = _nodeIdToPacketHandler.TryRemove(context.NodeId, out _);

ErrorUtilities.VerifyThrow(successRemoveFactory && successRemoveHandler, "Why are we trying to disconnect from a context that we already disconnected from? Did we call DisconnectFromHost twice?");
}

/// <summary>
/// Instantiates a new MSBuild or MSBuildTaskHost process acting as a child node.
/// </summary>
internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INodePacketFactory factory, INodePacketHandler handler, TaskHostConfiguration configuration, in TaskHostParameters taskHostParameters)
internal bool CreateNode(TaskHostNodeKey nodeKey, INodePacketFactory factory, INodePacketHandler handler, TaskHostConfiguration configuration, in TaskHostParameters taskHostParameters)
{
ErrorUtilities.VerifyThrowArgumentNull(factory);
ErrorUtilities.VerifyThrow(!_nodeIdToPacketFactory.ContainsKey(taskHostNodeId), "We should not already have a factory for this context! Did we forget to call DisconnectFromHost somewhere?");
ErrorUtilities.VerifyThrow(!_nodeContexts.ContainsKey(nodeKey), "We should not already have a node for this context! Did we forget to call DisconnectFromHost somewhere?");

HandshakeOptions hostContext = nodeKey.HandshakeOptions;

// If runtime host path is null it means we don't have MSBuild.dll path resolved and there is no need to include it in the command line arguments.
string commandLineArgsPlaceholder = "\"{0}\" /nologo /nodemode:2 /nodereuse:{1} /low:{2} ";

// Generate a unique node ID for communication purposes using atomic increment.
int communicationNodeId = Interlocked.Increment(ref _nextNodeId);

// Create callbacks that capture the TaskHostNodeKey
void OnNodeContextCreated(NodeContext context) => NodeContextCreated(context, nodeKey);

IList<NodeContext> nodeContexts;

// Handle .NET task host context
Expand All @@ -639,10 +677,10 @@ internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INode
nodeContexts = GetNodes(
runtimeHostPath,
string.Format(commandLineArgsPlaceholder, Path.Combine(msbuildAssemblyPath, Constants.MSBuildAssemblyName), NodeReuseIsEnabled(hostContext), ComponentHost.BuildParameters.LowPriority),
taskHostNodeId,
communicationNodeId,
this,
handshake,
NodeContextCreated,
OnNodeContextCreated,
NodeContextTerminated,
1);

Expand All @@ -663,10 +701,10 @@ internal bool CreateNode(HandshakeOptions hostContext, int taskHostNodeId, INode
nodeContexts = GetNodes(
msbuildLocation,
string.Format(commandLineArgsPlaceholder, string.Empty, NodeReuseIsEnabled(hostContext), ComponentHost.BuildParameters.LowPriority),
taskHostNodeId,
communicationNodeId,
this,
new Handshake(hostContext),
NodeContextCreated,
OnNodeContextCreated,
NodeContextTerminated,
1);

Expand All @@ -687,9 +725,10 @@ bool NodeReuseIsEnabled(HandshakeOptions hostContext)
/// <summary>
/// Method called when a context created.
/// </summary>
private void NodeContextCreated(NodeContext context)
private void NodeContextCreated(NodeContext context, TaskHostNodeKey nodeKey)
{
_nodeContexts[context.NodeId] = context;
_nodeContexts[nodeKey] = context;
_nodeIdToNodeKey[context.NodeId] = nodeKey;

// Start the asynchronous read.
context.BeginAsyncPacketRead();
Expand All @@ -702,19 +741,20 @@ private void NodeContextCreated(NodeContext context)
}

/// <summary>
/// Method called when a context terminates.
/// Method called when a context terminates (called from CreateNode callbacks or ShutdownAllNodes).
/// </summary>
private void NodeContextTerminated(int nodeId)
{
_nodeContexts.TryRemove(nodeId, out _);
// Remove from nodeKey-based lookup if we have it
if (_nodeIdToNodeKey.TryRemove(nodeId, out TaskHostNodeKey nodeKey))
{
_nodeContexts.TryRemove(nodeKey, out _);
}

// May also be removed by unnatural termination, so don't assume it's there
lock (_activeNodes)
{
if (_activeNodes.Contains(nodeId))
{
_activeNodes.Remove(nodeId);
}
_activeNodes.Remove(nodeId);

if (_activeNodes.Count == 0)
{
Expand Down
Loading
Loading