diff --git a/Source/MQTTnet/MqttClient.cs b/Source/MQTTnet/MqttClient.cs index 8f278b857..805fb3a41 100644 --- a/Source/MQTTnet/MqttClient.cs +++ b/Source/MQTTnet/MqttClient.cs @@ -126,17 +126,8 @@ public async Task ConnectAsync(MqttClientOptions option _unexpectedDisconnectPacket = null; - if (cancellationToken.CanBeCanceled) - { - connectResult = await ConnectInternal(adapter, cancellationToken).ConfigureAwait(false); - } - else - { - // Fall back to the general timeout specified in the options if the user passed - // CancellationToken.None or similar. - using var timeout = new CancellationTokenSource(Options.Timeout); - connectResult = await ConnectInternal(adapter, timeout.Token).ConfigureAwait(false); - } + using var timeoutCts = CreateTimeoutCancellationTokenSource(cancellationToken); + connectResult = await ConnectInternal(adapter, timeoutCts.Token).ConfigureAwait(false); if (connectResult.ResultCode != MqttClientConnectResultCode.Success) { @@ -211,15 +202,8 @@ public async Task DisconnectAsync(MqttClientDisconnectOptions options, Cancellat // must be thrown to let the caller know that the disconnect was not clean. var disconnectPacket = MqttDisconnectPacketFactory.Create(options); - if (cancellationToken.CanBeCanceled) - { - await Send(disconnectPacket, cancellationToken).ConfigureAwait(false); - } - else - { - using var timeout = new CancellationTokenSource(Options.Timeout); - await Send(disconnectPacket, timeout.Token).ConfigureAwait(false); - } + using var timeoutCts = CreateTimeoutCancellationTokenSource(cancellationToken); + await Send(disconnectPacket, timeoutCts.Token).ConfigureAwait(false); } finally { @@ -234,15 +218,8 @@ public async Task PingAsync(CancellationToken cancellationToken = default) ThrowIfDisposed(); ThrowIfNotConnected(); - if (cancellationToken.CanBeCanceled) - { - await Request(MqttPingReqPacket.Instance, cancellationToken).ConfigureAwait(false); - } - else - { - using var timeout = new CancellationTokenSource(Options.Timeout); - await Request(MqttPingReqPacket.Instance, timeout.Token).ConfigureAwait(false); - } + using var timeoutCts = CreateTimeoutCancellationTokenSource(cancellationToken); + await Request(MqttPingReqPacket.Instance, timeoutCts.Token).ConfigureAwait(false); } public Task PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken = default) @@ -322,16 +299,8 @@ public async Task SubscribeAsync(MqttClientSubscribeO var subscribePacket = MqttSubscribePacketFactory.Create(options); subscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - MqttSubAckPacket subAckPacket; - if (cancellationToken.CanBeCanceled) - { - subAckPacket = await Request(subscribePacket, cancellationToken).ConfigureAwait(false); - } - else - { - using var timeout = new CancellationTokenSource(Options.Timeout); - subAckPacket = await Request(subscribePacket, timeout.Token).ConfigureAwait(false); - } + using var timeoutCts = CreateTimeoutCancellationTokenSource(cancellationToken); + var subAckPacket = await Request(subscribePacket, timeoutCts.Token).ConfigureAwait(false); return MqttClientSubscribeResultFactory.Create(subscribePacket, subAckPacket); } @@ -356,16 +325,8 @@ public async Task UnsubscribeAsync(MqttClientUnsubs var unsubscribePacket = MqttUnsubscribePacketFactory.Create(options); unsubscribePacket.PacketIdentifier = _packetIdentifierProvider.GetNextPacketIdentifier(); - MqttUnsubAckPacket unsubAckPacket; - if (cancellationToken.CanBeCanceled) - { - unsubAckPacket = await Request(unsubscribePacket, cancellationToken).ConfigureAwait(false); - } - else - { - using var timeout = new CancellationTokenSource(Options.Timeout); - unsubAckPacket = await Request(unsubscribePacket, timeout.Token).ConfigureAwait(false); - } + using var timeoutCts = CreateTimeoutCancellationTokenSource(cancellationToken); + var unsubAckPacket = await Request(unsubscribePacket, timeoutCts.Token).ConfigureAwait(false); return MqttClientUnsubscribeResultFactory.Create(unsubscribePacket, unsubAckPacket); } @@ -491,6 +452,16 @@ void Cleanup() } } + CancellationTokenSource CreateTimeoutCancellationTokenSource(CancellationToken cancellationToken) + { + var timeoutCts = cancellationToken.CanBeCanceled + ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) + : new CancellationTokenSource(); + + timeoutCts.CancelAfter(Options.Timeout); + return timeoutCts; + } + MqttClientConnectionStatus CompareExchangeConnectionStatus(MqttClientConnectionStatus value, MqttClientConnectionStatus comparand) { return (MqttClientConnectionStatus)Interlocked.CompareExchange(ref _connectionStatus, (int)value, (int)comparand); @@ -1069,4 +1040,4 @@ async Task TrySendKeepAliveMessages(CancellationToken cancellationToken) _logger.Verbose("Stopped sending keep alive packets"); } } -} \ No newline at end of file +}