Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
continue;
}

if (await ShouldSkipEnqueue(senderId, session.Id, applicationMessage))
if (await _eventContainer.ShouldSkipEnqueue(senderId, session.Id, applicationMessage))
{
continue;
}
Expand Down Expand Up @@ -501,18 +501,20 @@ public async Task SubscribeAsync(string clientId, ICollection<MqttTopicFilter> t

var subscribeResult = await clientSession.Subscribe(fakeSubscribePacket, CancellationToken.None).ConfigureAwait(false);

if (subscribeResult.RetainedMessages != null)
if (subscribeResult.RetainedMessages == null)
{
foreach (var retainedMessageMatch in subscribeResult.RetainedMessages)
{
if (await ShouldSkipEnqueue(string.Empty, clientId, retainedMessageMatch.ApplicationMessage))
{
continue;
}
return;
}

var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch);
clientSession.EnqueueDataPacket(new MqttPacketBusItem(publishPacket));
foreach (var retainedMessageMatch in subscribeResult.RetainedMessages)
{
if (await _eventContainer.ShouldSkipEnqueue(string.Empty, clientId, retainedMessageMatch.ApplicationMessage))
{
continue;
}

var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch);
clientSession.EnqueueDataPacket(new MqttPacketBusItem(publishPacket));
}
}

Expand Down
15 changes: 11 additions & 4 deletions Source/MQTTnet.Server/Internal/MqttConnectedClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,20 @@ async Task HandleIncomingSubscribePacket(MqttSubscribePacket subscribePacket, Ca
return;
}

if (subscribeResult.RetainedMessages != null)
if (subscribeResult.RetainedMessages == null)
{
foreach (var retainedMessageMatch in subscribeResult.RetainedMessages)
return;
}

foreach (var retainedMessageMatch in subscribeResult.RetainedMessages)
{
if (await _eventContainer.ShouldSkipEnqueue(string.Empty, Id, retainedMessageMatch.ApplicationMessage))
{
var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch);
Session.EnqueueDataPacket(new MqttPacketBusItem(publishPacket));
continue;
}

var publishPacket = MqttPublishPacketFactory.Create(retainedMessageMatch);
Session.EnqueueDataPacket(new MqttPacketBusItem(publishPacket));
}
}

Expand Down
59 changes: 36 additions & 23 deletions Source/MQTTnet.Server/Internal/MqttServerEventContainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,62 @@

namespace MQTTnet.Server.Internal;

public sealed class MqttServerEventContainer
public class MqttServerEventContainer
{
public AsyncEvent<ApplicationMessageNotConsumedEventArgs> ApplicationMessageNotConsumedEvent { get; } = new AsyncEvent<ApplicationMessageNotConsumedEventArgs>();
public AsyncEvent<ApplicationMessageNotConsumedEventArgs> ApplicationMessageNotConsumedEvent { get; } = new();

public AsyncEvent<ClientAcknowledgedPublishPacketEventArgs> ClientAcknowledgedPublishPacketEvent { get; } = new AsyncEvent<ClientAcknowledgedPublishPacketEventArgs>();
public AsyncEvent<ClientAcknowledgedPublishPacketEventArgs> ClientAcknowledgedPublishPacketEvent { get; } = new();

public AsyncEvent<ClientConnectedEventArgs> ClientConnectedEvent { get; } = new AsyncEvent<ClientConnectedEventArgs>();
public AsyncEvent<ClientConnectedEventArgs> ClientConnectedEvent { get; } = new();

public AsyncEvent<ClientDisconnectedEventArgs> ClientDisconnectedEvent { get; } = new AsyncEvent<ClientDisconnectedEventArgs>();
public AsyncEvent<ClientDisconnectedEventArgs> ClientDisconnectedEvent { get; } = new();

public AsyncEvent<ClientSubscribedTopicEventArgs> ClientSubscribedTopicEvent { get; } = new AsyncEvent<ClientSubscribedTopicEventArgs>();
public AsyncEvent<ClientSubscribedTopicEventArgs> ClientSubscribedTopicEvent { get; } = new();

public AsyncEvent<ClientUnsubscribedTopicEventArgs> ClientUnsubscribedTopicEvent { get; } = new AsyncEvent<ClientUnsubscribedTopicEventArgs>();
public AsyncEvent<ClientUnsubscribedTopicEventArgs> ClientUnsubscribedTopicEvent { get; } = new();

public AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs> InterceptingClientEnqueueEvent { get; } = new AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs>();
public AsyncEvent<InterceptingClientApplicationMessageEnqueueEventArgs> InterceptingClientEnqueueEvent { get; } = new();

public AsyncEvent<ApplicationMessageEnqueuedEventArgs> ApplicationMessageEnqueuedOrDroppedEvent { get; } = new AsyncEvent<ApplicationMessageEnqueuedEventArgs>();
public AsyncEvent<ApplicationMessageEnqueuedEventArgs> ApplicationMessageEnqueuedOrDroppedEvent { get; } = new();

public AsyncEvent<QueueMessageOverwrittenEventArgs> QueuedApplicationMessageOverwrittenEvent { get; } = new AsyncEvent<QueueMessageOverwrittenEventArgs>();
public AsyncEvent<QueueMessageOverwrittenEventArgs> QueuedApplicationMessageOverwrittenEvent { get; } = new();

public AsyncEvent<InterceptingPacketEventArgs> InterceptingInboundPacketEvent { get; } = new AsyncEvent<InterceptingPacketEventArgs>();
public AsyncEvent<InterceptingPacketEventArgs> InterceptingInboundPacketEvent { get; } = new();

public AsyncEvent<InterceptingPacketEventArgs> InterceptingOutboundPacketEvent { get; } = new AsyncEvent<InterceptingPacketEventArgs>();
public AsyncEvent<InterceptingPacketEventArgs> InterceptingOutboundPacketEvent { get; } = new();

public AsyncEvent<InterceptingPublishEventArgs> InterceptingPublishEvent { get; } = new AsyncEvent<InterceptingPublishEventArgs>();
public AsyncEvent<InterceptingPublishEventArgs> InterceptingPublishEvent { get; } = new();

public AsyncEvent<InterceptingSubscriptionEventArgs> InterceptingSubscriptionEvent { get; } = new AsyncEvent<InterceptingSubscriptionEventArgs>();
public AsyncEvent<InterceptingSubscriptionEventArgs> InterceptingSubscriptionEvent { get; } = new();

public AsyncEvent<InterceptingUnsubscriptionEventArgs> InterceptingUnsubscriptionEvent { get; } = new AsyncEvent<InterceptingUnsubscriptionEventArgs>();
public AsyncEvent<InterceptingUnsubscriptionEventArgs> InterceptingUnsubscriptionEvent { get; } = new();

public AsyncEvent<LoadingRetainedMessagesEventArgs> LoadingRetainedMessagesEvent { get; } = new AsyncEvent<LoadingRetainedMessagesEventArgs>();
public AsyncEvent<LoadingRetainedMessagesEventArgs> LoadingRetainedMessagesEvent { get; } = new();

public AsyncEvent<EventArgs> PreparingSessionEvent { get; } = new AsyncEvent<EventArgs>();
public AsyncEvent<EventArgs> PreparingSessionEvent { get; } = new();

public AsyncEvent<RetainedMessageChangedEventArgs> RetainedMessageChangedEvent { get; } = new AsyncEvent<RetainedMessageChangedEventArgs>();
public AsyncEvent<RetainedMessageChangedEventArgs> RetainedMessageChangedEvent { get; } = new();

public AsyncEvent<EventArgs> RetainedMessagesClearedEvent { get; } = new AsyncEvent<EventArgs>();
public AsyncEvent<EventArgs> RetainedMessagesClearedEvent { get; } = new();

public AsyncEvent<SessionDeletedEventArgs> SessionDeletedEvent { get; } = new AsyncEvent<SessionDeletedEventArgs>();
public AsyncEvent<SessionDeletedEventArgs> SessionDeletedEvent { get; } = new();

public AsyncEvent<EventArgs> StartedEvent { get; } = new AsyncEvent<EventArgs>();
public AsyncEvent<EventArgs> StartedEvent { get; } = new();

public AsyncEvent<EventArgs> StoppedEvent { get; } = new AsyncEvent<EventArgs>();
public AsyncEvent<EventArgs> StoppedEvent { get; } = new();

public AsyncEvent<ValidatingConnectionEventArgs> ValidatingConnectionEvent { get; } = new AsyncEvent<ValidatingConnectionEventArgs>();
public AsyncEvent<ValidatingConnectionEventArgs> ValidatingConnectionEvent { get; } = new();

public async Task<bool> ShouldSkipEnqueue(string senderClientId, string clientId, MqttApplicationMessage applicationMessage)
{
if (!InterceptingClientEnqueueEvent.HasHandlers)
{
return false;
}

var eventArgs = new InterceptingClientApplicationMessageEnqueueEventArgs(senderClientId, clientId, applicationMessage);
await InterceptingClientEnqueueEvent.InvokeAsync(eventArgs).ConfigureAwait(false);

return !eventArgs.AcceptEnqueue;
}
}
2 changes: 1 addition & 1 deletion Source/MQTTnet.Server/Status/MqttSessionStatus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task<InjectMqttApplicationMessageResult> DeliverApplicationMessageA

await packetBusItem.WaitAsync().ConfigureAwait(false);

var injectResult = new InjectMqttApplicationMessageResult()
var injectResult = new InjectMqttApplicationMessageResult
{
PacketIdentifier = publishPacket.PacketIdentifier
};
Expand Down
5 changes: 5 additions & 0 deletions Source/MQTTnet.Tests/BaseTestClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,9 @@ protected Task LongTestDelay()
{
return Task.Delay(TimeSpan.FromSeconds(1));
}

protected Task MediumTestDelay()
{
return Task.Delay(TimeSpan.FromSeconds(0.5));
}
}
49 changes: 39 additions & 10 deletions Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace MQTTnet.Tests.Server;

// ReSharper disable InconsistentNaming
[TestClass]
public sealed class Retained_Messages_Tests : BaseTestClass
public class Retained_Messages_Tests : BaseTestClass
{
[TestMethod]
public async Task Clear_Retained_Message_With_Empty_Payload()
Expand All @@ -31,9 +31,9 @@ public async Task Clear_Retained_Message_With_Empty_Payload()
var c2 = await testEnvironment.ConnectClient();
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);

await Task.Delay(200);
await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });
await Task.Delay(500);

await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(0);
}
Expand All @@ -54,9 +54,9 @@ public async Task Clear_Retained_Message_With_Null_Payload()
var c2 = await testEnvironment.ConnectClient();
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);

await Task.Delay(200);
await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });
await Task.Delay(500);

await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(0);
}
Expand All @@ -82,7 +82,7 @@ await c1.PublishAsync(
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);
await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtLeastOnce });

await LongTestDelay();
await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(1);

Expand Down Expand Up @@ -110,7 +110,7 @@ await c1.PublishAsync(
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);
await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce });

await LongTestDelay();
await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(1);

Expand All @@ -131,7 +131,7 @@ public async Task Receive_No_Retained_Message_After_Subscribe()
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);
await c2.SubscribeAsync(new MqttTopicFilterBuilder().WithTopic("retained_other").Build());

await Task.Delay(500);
await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(0);
}
Expand Down Expand Up @@ -178,10 +178,10 @@ await c1.PublishAsync(
var c2 = await testEnvironment.ConnectClient();
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);

await Task.Delay(200);
// Using QoS 2 will lead to 1 instead because the publish was made with QoS level 1 (see 3.8.4 SUBSCRIBE Actions)!
await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.ExactlyOnce });
await Task.Delay(500);

await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(1);
}
Expand Down Expand Up @@ -248,4 +248,33 @@ public async Task Server_Reports_Retained_Messages_Supported_V5()

Assert.IsTrue(connectResult.RetainAvailable);
}

[TestMethod]
public async Task Skip_Enqueue_Of_Retained_Message()
{
using var testEnvironment = CreateTestEnvironment();

var server = await testEnvironment.StartServer();
server.InterceptingClientEnqueueAsync += e =>
{
e.AcceptEnqueue = false;

return Task.CompletedTask;
};

var c1 = await testEnvironment.ConnectClient();

await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().Build());

await c1.DisconnectAsync();

var c2 = await testEnvironment.ConnectClient();
var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2);

await c2.SubscribeAsync(new MqttTopicFilter { Topic = "retained", QualityOfServiceLevel = MqttQualityOfServiceLevel.AtMostOnce });

await MediumTestDelay();

messageHandler.AssertReceivedCountEquals(0);
}
}