Skip to content

Merge changes from tls-channel to prevent accidentally calling SSLEng… #1726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ public TlsChannelImpl(
private final Lock readLock = new ReentrantLock();
private final Lock writeLock = new ReentrantLock();

private volatile boolean negotiated = false;
private boolean handshakeStarted = false;

private volatile boolean handshakeCompleted = false;

/**
* Whether a IOException was received from the underlying channel or from the {@link SSLEngine}.
Expand Down Expand Up @@ -526,22 +528,35 @@ public void handshake() throws IOException {
}

private void doHandshake(boolean force) throws IOException, EofException {
if (!force && negotiated) return;
if (!force && handshakeCompleted) {
return;
}
initLock.lock();
try {
if (invalid || shutdownSent) throw new ClosedChannelException();
if (force || !negotiated) {
engine.beginHandshake();
LOGGER.trace("Called engine.beginHandshake()");
if (force || !handshakeCompleted) {

if (!handshakeStarted) {
engine.beginHandshake();
LOGGER.trace("Called engine.beginHandshake()");

// Some engines that do not support renegotiations may be sensitive to calling
// SSLEngine.beginHandshake() more than once. This guard prevents that.
// See: https://github.com/marianobarrios/tls-channel/issues/197
handshakeStarted = true;
}

handshake(Optional.empty(), Optional.empty());

handshakeCompleted = true;

// call client code
try {
initSessionCallback.accept(engine.getSession());
} catch (Exception e) {
LOGGER.trace("client code threw exception in session initialization callback", e);
throw new TlsChannelCallbackException("session initialization callback failed", e);
}
negotiated = true;
}
} finally {
initLock.unlock();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,52 @@

package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.getPrimaryServerDescription;
import static com.mongodb.internal.connection.OperationContext.simpleOperationContext;
import static java.lang.String.format;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class TlsChannelStreamFunctionalTest {
private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build();
Expand Down Expand Up @@ -98,6 +114,7 @@ void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, I
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class);
ServerSocket serverSocket = new ServerSocket(0, 1)) {

SingleResultSpyCaptor<SocketChannel> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor);

Expand Down Expand Up @@ -147,4 +164,35 @@ public T answer(final InvocationOnMock invocationOnMock) throws Throwable {
private static OperationContext createOperationContext(final int connectTimeoutMs) {
return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs)));
}

@Test
@DisplayName("should not call beginHandshake more than once during TLS session establishment")
void shouldNotCallBeginHandshakeMoreThenOnceDuringTlsSessionEstablishment() throws Exception {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test case, as upstream didn't include one to cover this change.

assumeTrue(ClusterFixture.getSslSettings().isEnabled());

//given
try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver())) {

SSLContext sslContext = Mockito.spy(SSLContext.getDefault());
SingleResultSpyCaptor<SSLEngine> singleResultSpyCaptor = new SingleResultSpyCaptor<>();
when(sslContext.createSSLEngine(anyString(), anyInt())).thenAnswer(singleResultSpyCaptor);

StreamFactory streamFactory = streamFactoryFactory.create(
SocketSettings.builder().build(),
SslSettings.builder(ClusterFixture.getSslSettings())
.context(sslContext)
.build());

Stream stream = streamFactory.create(getPrimaryServerDescription().getAddress());
stream.open(ClusterFixture.OPERATION_CONTEXT);
ByteBuf wrap = new ByteBufNIO(ByteBuffer.wrap(new byte[]{1, 3, 4}));

//when
stream.write(Collections.singletonList(wrap), ClusterFixture.OPERATION_CONTEXT);

//then
SECONDS.sleep(5);
verify(singleResultSpyCaptor.getResult(), times(1)).beginHandshake();
}
}
}