diff --git a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs index 88b46ec41..b4985bd2d 100644 --- a/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs +++ b/Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs @@ -189,7 +189,7 @@ public async Task DispatchApplicationMessage( continue; } - if (await ShouldSkipEnqueue(senderId, session.Id, applicationMessage)) + if (await _eventContainer.ShouldSkipEnqueue(senderId, session.Id, applicationMessage)) { continue; } @@ -501,18 +501,20 @@ public async Task SubscribeAsync(string clientId, ICollection t var subscribeResult = await clientSession.Subscribe(fakeSubscribePacket, CancellationToken.None).ConfigureAwait(false); - if (subscribeResult.RetainedMessages != null) + if (subscribeResult.RetainedMessages == null) { - foreach (var retainedMessageMatch in subscribeResult.RetainedMessages) - { - if (await ShouldSkipEnqueue(string.Empty, clientId, retainedMessageMatch.ApplicationMessage)) - { - continue; - } + return; + } - var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch); - clientSession.EnqueueDataPacket(new MqttPacketBusItem(publishPacket)); + foreach (var retainedMessageMatch in subscribeResult.RetainedMessages) + { + if (await _eventContainer.ShouldSkipEnqueue(string.Empty, clientId, retainedMessageMatch.ApplicationMessage)) + { + continue; } + + var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch); + clientSession.EnqueueDataPacket(new MqttPacketBusItem(publishPacket)); } } diff --git a/Source/MQTTnet.Server/Internal/MqttConnectedClient.cs b/Source/MQTTnet.Server/Internal/MqttConnectedClient.cs index bfdb21307..cb765f1eb 100644 --- a/Source/MQTTnet.Server/Internal/MqttConnectedClient.cs +++ b/Source/MQTTnet.Server/Internal/MqttConnectedClient.cs @@ -284,13 +284,20 @@ async Task HandleIncomingSubscribePacket(MqttSubscribePacket subscribePacket, Ca return; } - if (subscribeResult.RetainedMessages != null) + if (subscribeResult.RetainedMessages == null) { - foreach (var retainedMessageMatch in subscribeResult.RetainedMessages) + return; + } + + foreach (var retainedMessageMatch in subscribeResult.RetainedMessages) + { + if (await _eventContainer.ShouldSkipEnqueue(string.Empty, Id, retainedMessageMatch.ApplicationMessage)) { - var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch); - Session.EnqueueDataPacket(new MqttPacketBusItem(publishPacket)); + continue; } + + var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch); + Session.EnqueueDataPacket(new MqttPacketBusItem(publishPacket)); } } diff --git a/Source/MQTTnet.Server/Internal/MqttServerEventContainer.cs b/Source/MQTTnet.Server/Internal/MqttServerEventContainer.cs index 12aacad18..cd8c905dd 100644 --- a/Source/MQTTnet.Server/Internal/MqttServerEventContainer.cs +++ b/Source/MQTTnet.Server/Internal/MqttServerEventContainer.cs @@ -6,49 +6,62 @@ namespace MQTTnet.Server.Internal; -public sealed class MqttServerEventContainer +public class MqttServerEventContainer { - public AsyncEvent ApplicationMessageNotConsumedEvent { get; } = new AsyncEvent(); + public AsyncEvent ApplicationMessageNotConsumedEvent { get; } = new(); - public AsyncEvent ClientAcknowledgedPublishPacketEvent { get; } = new AsyncEvent(); + public AsyncEvent ClientAcknowledgedPublishPacketEvent { get; } = new(); - public AsyncEvent ClientConnectedEvent { get; } = new AsyncEvent(); + public AsyncEvent ClientConnectedEvent { get; } = new(); - public AsyncEvent ClientDisconnectedEvent { get; } = new AsyncEvent(); + public AsyncEvent ClientDisconnectedEvent { get; } = new(); - public AsyncEvent ClientSubscribedTopicEvent { get; } = new AsyncEvent(); + public AsyncEvent ClientSubscribedTopicEvent { get; } = new(); - public AsyncEvent ClientUnsubscribedTopicEvent { get; } = new AsyncEvent(); + public AsyncEvent ClientUnsubscribedTopicEvent { get; } = new(); - public AsyncEvent InterceptingClientEnqueueEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingClientEnqueueEvent { get; } = new(); - public AsyncEvent ApplicationMessageEnqueuedOrDroppedEvent { get; } = new AsyncEvent(); + public AsyncEvent ApplicationMessageEnqueuedOrDroppedEvent { get; } = new(); - public AsyncEvent QueuedApplicationMessageOverwrittenEvent { get; } = new AsyncEvent(); + public AsyncEvent QueuedApplicationMessageOverwrittenEvent { get; } = new(); - public AsyncEvent InterceptingInboundPacketEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingInboundPacketEvent { get; } = new(); - public AsyncEvent InterceptingOutboundPacketEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingOutboundPacketEvent { get; } = new(); - public AsyncEvent InterceptingPublishEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingPublishEvent { get; } = new(); - public AsyncEvent InterceptingSubscriptionEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingSubscriptionEvent { get; } = new(); - public AsyncEvent InterceptingUnsubscriptionEvent { get; } = new AsyncEvent(); + public AsyncEvent InterceptingUnsubscriptionEvent { get; } = new(); - public AsyncEvent LoadingRetainedMessagesEvent { get; } = new AsyncEvent(); + public AsyncEvent LoadingRetainedMessagesEvent { get; } = new(); - public AsyncEvent PreparingSessionEvent { get; } = new AsyncEvent(); + public AsyncEvent PreparingSessionEvent { get; } = new(); - public AsyncEvent RetainedMessageChangedEvent { get; } = new AsyncEvent(); + public AsyncEvent RetainedMessageChangedEvent { get; } = new(); - public AsyncEvent RetainedMessagesClearedEvent { get; } = new AsyncEvent(); + public AsyncEvent RetainedMessagesClearedEvent { get; } = new(); - public AsyncEvent SessionDeletedEvent { get; } = new AsyncEvent(); + public AsyncEvent SessionDeletedEvent { get; } = new(); - public AsyncEvent StartedEvent { get; } = new AsyncEvent(); + public AsyncEvent StartedEvent { get; } = new(); - public AsyncEvent StoppedEvent { get; } = new AsyncEvent(); + public AsyncEvent StoppedEvent { get; } = new(); - public AsyncEvent ValidatingConnectionEvent { get; } = new AsyncEvent(); + public AsyncEvent ValidatingConnectionEvent { get; } = new(); + + public async Task ShouldSkipEnqueue(string senderClientId, string clientId, MqttApplicationMessage applicationMessage) + { + if (!InterceptingClientEnqueueEvent.HasHandlers) + { + return false; + } + + var eventArgs = new InterceptingClientApplicationMessageEnqueueEventArgs(senderClientId, clientId, applicationMessage); + await InterceptingClientEnqueueEvent.InvokeAsync(eventArgs).ConfigureAwait(false); + + return !eventArgs.AcceptEnqueue; + } } \ No newline at end of file diff --git a/Source/MQTTnet.Server/Status/MqttSessionStatus.cs b/Source/MQTTnet.Server/Status/MqttSessionStatus.cs index 780c08130..6d76ed590 100644 --- a/Source/MQTTnet.Server/Status/MqttSessionStatus.cs +++ b/Source/MQTTnet.Server/Status/MqttSessionStatus.cs @@ -58,7 +58,7 @@ public async Task DeliverApplicationMessageA await packetBusItem.WaitAsync().ConfigureAwait(false); - var injectResult = new InjectMqttApplicationMessageResult() + var injectResult = new InjectMqttApplicationMessageResult { PacketIdentifier = publishPacket.PacketIdentifier }; diff --git a/Source/MQTTnet.Tests/BaseTestClass.cs b/Source/MQTTnet.Tests/BaseTestClass.cs index aabef5ae6..b70881575 100644 --- a/Source/MQTTnet.Tests/BaseTestClass.cs +++ b/Source/MQTTnet.Tests/BaseTestClass.cs @@ -24,4 +24,9 @@ protected Task LongTestDelay() { return Task.Delay(TimeSpan.FromSeconds(1)); } + + protected Task MediumTestDelay() + { + return Task.Delay(TimeSpan.FromSeconds(0.5)); + } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs index 562b1134f..0354a7bb6 100644 --- a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs @@ -13,7 +13,7 @@ namespace MQTTnet.Tests.Server; // ReSharper disable InconsistentNaming [TestClass] -public sealed class Retained_Messages_Tests : BaseTestClass +public class Retained_Messages_Tests : BaseTestClass { [TestMethod] public async Task Clear_Retained_Message_With_Empty_Payload() @@ -31,9 +31,9 @@ public async Task Clear_Retained_Message_With_Empty_Payload() var c2 = await testEnvironment.ConnectClient(); var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); - await Task.Delay(200); await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - await Task.Delay(500); + + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(0); } @@ -54,9 +54,9 @@ public async Task Clear_Retained_Message_With_Null_Payload() var c2 = await testEnvironment.ConnectClient(); var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); - await Task.Delay(200); await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); - await Task.Delay(500); + + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(0); } @@ -82,7 +82,7 @@ await c1.PublishAsync( var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce }); - await LongTestDelay(); + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(1); @@ -110,7 +110,7 @@ await c1.PublishAsync( var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce }); - await LongTestDelay(); + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(1); @@ -131,7 +131,7 @@ public async Task Receive_No_Retained_Message_After_Subscribe() var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); await c2.SubscribeAsync(new MqttTopicFilterBuilder().WithTopic("retained_other").Build()); - await Task.Delay(500); + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(0); } @@ -178,10 +178,10 @@ await c1.PublishAsync( var c2 = await testEnvironment.ConnectClient(); var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); - await Task.Delay(200); // Using QoS 2 will lead to 1 instead because the publish was made with QoS level 1 (see 3.8.4 SUBSCRIBE Actions)! await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce }); - await Task.Delay(500); + + await MediumTestDelay(); messageHandler.AssertReceivedCountEquals(1); } @@ -248,4 +248,33 @@ public async Task Server_Reports_Retained_Messages_Supported_V5() Assert.IsTrue(connectResult.RetainAvailable); } + + [TestMethod] + public async Task Skip_Enqueue_Of_Retained_Message() + { + using var testEnvironment = CreateTestEnvironment(); + + var server = await testEnvironment.StartServer(); + server.InterceptingClientEnqueueAsync += e => + { + e.AcceptEnqueue = false; + + return Task.CompletedTask; + }; + + var c1 = await testEnvironment.ConnectClient(); + + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build()); + + await c1.DisconnectAsync(); + + var c2 = await testEnvironment.ConnectClient(); + var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); + + await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce }); + + await MediumTestDelay(); + + messageHandler.AssertReceivedCountEquals(0); + } } \ No newline at end of file