diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java index 3c845ce6d08..20bc69e81f0 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/impl/TlsChannelImpl.java @@ -159,7 +159,9 @@ public TlsChannelImpl( private final Lock readLock = new ReentrantLock(); private final Lock writeLock = new ReentrantLock(); - private volatile boolean negotiated = false; + private boolean handshakeStarted = false; + + private volatile boolean handshakeCompleted = false; /** * Whether a IOException was received from the underlying channel or from the {@link SSLEngine}. @@ -526,14 +528,28 @@ public void handshake() throws IOException { } private void doHandshake(boolean force) throws IOException, EofException { - if (!force && negotiated) return; + if (!force && handshakeCompleted) { + return; + } initLock.lock(); try { if (invalid || shutdownSent) throw new ClosedChannelException(); - if (force || !negotiated) { - engine.beginHandshake(); - LOGGER.trace("Called engine.beginHandshake()"); + if (force || !handshakeCompleted) { + + if (!handshakeStarted) { + engine.beginHandshake(); + LOGGER.trace("Called engine.beginHandshake()"); + + // Some engines that do not support renegotiations may be sensitive to calling + // SSLEngine.beginHandshake() more than once. This guard prevents that. + // See: https://github.com/marianobarrios/tls-channel/issues/197 + handshakeStarted = true; + } + handshake(Optional.empty(), Optional.empty()); + + handshakeCompleted = true; + // call client code try { initSessionCallback.accept(engine.getSession()); @@ -541,7 +557,6 @@ private void doHandshake(boolean force) throws IOException, EofException { LOGGER.trace("client code threw exception in session initialization callback", e); throw new TlsChannelCallbackException("session initialization callback failed", e); } - negotiated = true; } } finally { initLock.unlock(); diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 3f80fcddfa3..3af1eaa33e1 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -16,12 +16,17 @@ package com.mongodb.internal.connection; +import com.mongodb.ClusterFixture; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import org.bson.ByteBuf; +import org.bson.ByteBufNIO; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.MockedStatic; @@ -29,23 +34,34 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; import java.io.IOException; import java.net.ServerSocket; +import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; import java.nio.channels.SocketChannel; +import java.util.Collections; import java.util.concurrent.TimeUnit; +import static com.mongodb.ClusterFixture.getPrimaryServerDescription; import static com.mongodb.internal.connection.OperationContext.simpleOperationContext; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; class TlsChannelStreamFunctionalTest { private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build(); @@ -98,6 +114,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver()); MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class); ServerSocket serverSocket = new ServerSocket(0, 1)) { + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor); @@ -147,4 +164,35 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable { private static OperationContext createOperationContext(final int connectTimeoutMs) { return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs))); } + + @Test + @DisplayName("should not call beginHandshake more than once during TLS session establishment") + void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception { + assumeTrue(ClusterFixture.getSslSettings().isEnabled()); + + //given + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) { + + SSLContext sslContext = Mockito.spy(SSLContext.getDefault()); + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); + when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor); + + StreamFactory streamFactory = streamFactoryFactory.create( + SocketSettings.builder().build(), + SslSettings.builder(ClusterFixture.getSslSettings()) + .context(sslContext) + .build()); + + Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress()); + stream.open(ClusterFixture.OPERATION_CONTEXT); + ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4})); + + //when + stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT); + + //then + SECONDS.sleep(5); + verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake(); + } + } }