Skip to content

Commit a5d0419

Browse files
xinlian12Annie Liang
andauthored
RntbdMockServer (Azure#17200)
* Bring in RntbdMockServer and add connectionStateListener test on top of it Co-authored-by: Annie Liang <[email protected]>
1 parent 3f7dee5 commit a5d0419

23 files changed

+3350
-0
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.cosmos.implementation.directconnectivity;
5+
6+
import com.azure.cosmos.DirectConnectionConfig;
7+
import com.azure.cosmos.implementation.Configs;
8+
import com.azure.cosmos.implementation.ConnectionPolicy;
9+
import com.azure.cosmos.implementation.Document;
10+
import com.azure.cosmos.implementation.OperationType;
11+
import com.azure.cosmos.implementation.ResourceType;
12+
import com.azure.cosmos.implementation.RxDocumentServiceRequest;
13+
import com.azure.cosmos.implementation.UserAgentContainer;
14+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.TcpServerFactory;
15+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.TcpServer;
16+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.RequestResponseType;
17+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.SslContextUtils;
18+
import com.azure.cosmos.implementation.routing.PartitionKeyRangeIdentity;
19+
import io.netty.handler.ssl.SslContext;
20+
import org.mockito.Mockito;
21+
import org.slf4j.Logger;
22+
import org.slf4j.LoggerFactory;
23+
import org.testng.annotations.DataProvider;
24+
import org.testng.annotations.Test;
25+
26+
import java.util.HashMap;
27+
import java.util.UUID;
28+
import java.util.concurrent.ExecutionException;
29+
30+
import static com.azure.cosmos.implementation.TestUtils.mockDiagnosticsClientContext;
31+
32+
public class ConnectionStateListenerTest {
33+
34+
private static int port = 8082;
35+
private static String serverUriString = "rntbd://localhost:" + port;
36+
private static final Logger logger = LoggerFactory.getLogger(ConnectionStateListenerTest.class);
37+
38+
@DataProvider(name = "connectionStateListenerConfigProvider")
39+
public Object[][] connectionStateListenerConfigProvider() {
40+
return new Object[][]{
41+
// isTcpConnectionEndpointRediscoveryEnabled, serverResponseType, GlobalAddressResolver.updateAddresses() called times
42+
{true, RequestResponseType.CHANNEL_FIN, 1},
43+
{false, RequestResponseType.CHANNEL_FIN, 0},
44+
{true, RequestResponseType.CHANNEL_RST, 0},
45+
{false, RequestResponseType.CHANNEL_RST, 0},
46+
};
47+
}
48+
49+
@Test(groups = { "unit" }, dataProvider = "connectionStateListenerConfigProvider")
50+
public void connectionStateListener_OnConnectionEvent(
51+
boolean isTcpConnectionEndpointRediscoveryEnabled,
52+
RequestResponseType responseType,
53+
int times) throws ExecutionException, InterruptedException {
54+
55+
TcpServer server = TcpServerFactory.startNewRntbdServer(port);
56+
// Inject fake response
57+
server.injectServerResponse(responseType);
58+
59+
ConnectionPolicy connectionPolicy = new ConnectionPolicy(DirectConnectionConfig.getDefaultConfig());
60+
connectionPolicy.setTcpConnectionEndpointRediscoveryEnabled(isTcpConnectionEndpointRediscoveryEnabled);
61+
62+
GlobalAddressResolver addressResolver = Mockito.mock(GlobalAddressResolver.class);
63+
64+
SslContext sslContext = SslContextUtils.CreateSslContext("client.jks", false);
65+
66+
Configs config = Mockito.mock(Configs.class);
67+
Mockito.doReturn(sslContext).when(config).getSslContext();
68+
69+
RntbdTransportClient client = new RntbdTransportClient(
70+
config,
71+
connectionPolicy,
72+
new UserAgentContainer(),
73+
addressResolver);
74+
75+
RxDocumentServiceRequest req =
76+
RxDocumentServiceRequest.create(mockDiagnosticsClientContext(), OperationType.Create, ResourceType.Document,
77+
"dbs/fakedb/colls/fakeColls",
78+
getDocumentDefinition(), new HashMap<>());
79+
req.setPartitionKeyRangeIdentity(new PartitionKeyRangeIdentity("fakeCollectionId","fakePartitionKeyRangeId"));
80+
81+
Uri targetUri = new Uri(serverUriString);
82+
try {
83+
client.invokeStoreAsync(targetUri, req).block();
84+
} catch (Exception e) {
85+
logger.info("expected failed request with reason {}", e);
86+
}
87+
finally {
88+
Mockito.verify(addressResolver, Mockito.times(times)).updateAddresses(Mockito.any(), Mockito.any());
89+
}
90+
91+
TcpServerFactory.shutdownRntbdServer(server);
92+
}
93+
94+
private Document getDocumentDefinition() {
95+
String uuid = UUID.randomUUID().toString();
96+
Document doc = new Document(String.format("{ "
97+
+ "\"id\": \"%s\", "
98+
+ "\"mypk\": \"%s\", "
99+
+ "\"sgmts\": [[6519456, 1471916863], [2498434, 1455671440]]"
100+
+ "}"
101+
, uuid, uuid));
102+
return doc;
103+
}
104+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.cosmos.implementation.directconnectivity.TcpServerMock;
5+
6+
/**
7+
* Use this method to indicate the response type client want to receive.
8+
*/
9+
public enum RequestResponseType {
10+
CHANNEL_FIN,
11+
CHANNEL_RST
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.cosmos.implementation.directconnectivity.TcpServerMock;
5+
6+
import io.netty.handler.ssl.SslContext;
7+
import io.netty.handler.ssl.SslContextBuilder;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
11+
import javax.net.ssl.KeyManagerFactory;
12+
import javax.net.ssl.TrustManagerFactory;
13+
import java.io.InputStream;
14+
import java.security.KeyStore;
15+
16+
/**
17+
*
18+
* The .jks file can be generated as following:
19+
* // openssl req -x509 -nodes -days 365 -newkey rsa:4096 -keyout testKey.pem -out testCert.crt -subj /CN=localhost
20+
* // keytool -import -v -trustcacerts -alias client-alias -file testCert.pem -keystore client.jks -keypass rntbdTest -storepass rntbdTest
21+
* // openssl pkcs12 -export -in cert.pem -inkey testKey.pem -certfile testCert.pem -out keystore.p12
22+
* // keytool -importkeystore -srckeystore keystore.p12 -srcstoretype pkcs12 -destkeystore server.jks -deststoretype JKS
23+
*
24+
*/
25+
public class SslContextUtils {
26+
private static final String STOREPASS = "rntbdTest";
27+
private static final Logger logger = LoggerFactory.getLogger(SslContextUtils.class);
28+
29+
public static SslContext CreateSslContext(String keyStore, boolean isServer) {
30+
SslContext sslContext = null;
31+
32+
try {
33+
final ClassLoader classloader = SslContextUtils.class.getClassLoader();
34+
final InputStream inputStream = classloader.getResourceAsStream(keyStore);
35+
36+
final KeyStore trustStore = KeyStore.getInstance("jks");
37+
trustStore.load(inputStream, STOREPASS.toCharArray());
38+
39+
final KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
40+
keyManagerFactory.init(trustStore, STOREPASS.toCharArray());
41+
42+
final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
43+
trustManagerFactory.init(trustStore);
44+
45+
if (isServer) {
46+
sslContext = SslContextBuilder.forServer(keyManagerFactory).trustManager(trustManagerFactory).build();
47+
} else {
48+
sslContext = SslContextBuilder.forClient().keyManager(keyManagerFactory).trustManager(trustManagerFactory).build();
49+
}
50+
} catch (Exception exception) {
51+
logger.error("Initializing sslContext failed {}", exception);
52+
}
53+
54+
return sslContext;
55+
}
56+
}
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* The MIT License (MIT)
3+
* Copyright (c) 2018 Microsoft Corporation
4+
*
5+
* Permission is hereby granted, free of charge, to any person obtaining a copy
6+
* of this software and associated documentation files (the "Software"), to deal
7+
* in the Software without restriction, including without limitation the rights
8+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
* copies of the Software, and to permit persons to whom the Software is
10+
* furnished to do so, subject to the following conditions:
11+
*
12+
* The above copyright notice and this permission notice shall be included in all
13+
* copies or substantial portions of the Software.
14+
*
15+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
* SOFTWARE.
22+
*
23+
*/
24+
25+
package com.azure.cosmos.implementation.directconnectivity.TcpServerMock;
26+
27+
import com.azure.cosmos.implementation.Utils;
28+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.rntbd.ServerRntbdContextEncoder;
29+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.rntbd.ServerRntbdContextRequestDecoder;
30+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.rntbd.ServerRntbdRequestDecoder;
31+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.rntbd.ServerRntbdRequestFramer;
32+
import com.azure.cosmos.implementation.directconnectivity.TcpServerMock.rntbd.ServerRntbdRequestManager;
33+
import io.netty.bootstrap.ServerBootstrap;
34+
import io.netty.channel.ChannelFuture;
35+
import io.netty.channel.ChannelInitializer;
36+
import io.netty.channel.ChannelOption;
37+
import io.netty.channel.ChannelPipeline;
38+
import io.netty.channel.EventLoopGroup;
39+
import io.netty.channel.nio.NioEventLoopGroup;
40+
import io.netty.channel.socket.SocketChannel;
41+
import io.netty.channel.socket.nio.NioServerSocketChannel;
42+
import io.netty.handler.logging.LogLevel;
43+
import io.netty.handler.logging.LoggingHandler;
44+
import io.netty.handler.ssl.SslContext;
45+
import io.netty.handler.ssl.SslHandler;
46+
import io.netty.util.concurrent.Promise;
47+
import org.slf4j.Logger;
48+
import org.slf4j.LoggerFactory;
49+
50+
import javax.net.ssl.SSLEngine;
51+
52+
public class TcpServer {
53+
54+
private final static Logger logger = LoggerFactory.getLogger(TcpServer.class);
55+
private static final String SERVER_KEYSTORE = "server.jks";
56+
private final int port;
57+
private final EventLoopGroup parent;
58+
private final EventLoopGroup child;
59+
private final ServerRntbdRequestManager requestManager; // Use to inject fake server response.
60+
61+
public TcpServer(int port) {
62+
this.port = port;
63+
this.parent = new NioEventLoopGroup();
64+
this.child = new NioEventLoopGroup();
65+
requestManager = new ServerRntbdRequestManager();
66+
}
67+
68+
public void start(Promise<Boolean> promise) throws InterruptedException {
69+
70+
SslContext sslContext = SslContextUtils.CreateSslContext(SERVER_KEYSTORE, true);
71+
Utils.checkNotNullOrThrow(sslContext, "sslContext", "");
72+
73+
try {
74+
ServerBootstrap bootstrap = new ServerBootstrap();
75+
76+
bootstrap.group(parent, child)
77+
.channel(NioServerSocketChannel.class)
78+
.childHandler(new ChannelInitializer<SocketChannel>() {
79+
@Override
80+
public void initChannel(SocketChannel channel) throws Exception {
81+
82+
SSLEngine engine = sslContext.newEngine(channel.alloc());
83+
engine.setUseClientMode(false);
84+
85+
ChannelPipeline pipeline = channel.pipeline();
86+
pipeline.addLast(
87+
new SslHandler(engine),
88+
new ServerRntbdRequestFramer(),
89+
new ServerRntbdRequestDecoder(),
90+
new ServerRntbdContextRequestDecoder(),
91+
new ServerRntbdContextEncoder(),
92+
requestManager
93+
);
94+
95+
LogLevel logLevel = null;
96+
97+
if (logger.isTraceEnabled()) {
98+
logLevel = LogLevel.TRACE;
99+
} else if (logger.isDebugEnabled()) {
100+
logLevel = LogLevel.DEBUG;
101+
}
102+
103+
if (logLevel != null) {
104+
pipeline.addFirst(new LoggingHandler(logLevel));
105+
}
106+
}
107+
})
108+
.childOption(ChannelOption.SO_KEEPALIVE, true);
109+
110+
ChannelFuture channelFuture = bootstrap.bind(port).sync().addListener((ChannelFuture f) -> {
111+
if (f.isSuccess()) {
112+
logger.info(
113+
"{} started and listening for connections on {}",
114+
TcpServer.class.getSimpleName(),
115+
f.channel().localAddress());
116+
promise.setSuccess(Boolean.TRUE);
117+
}
118+
});
119+
120+
channelFuture.channel().closeFuture().sync().addListener((ChannelFuture f) -> {
121+
logger.info("Server channel closed.");
122+
});
123+
124+
} finally {
125+
parent.shutdownGracefully().sync();
126+
child.shutdownGracefully().sync();
127+
}
128+
}
129+
130+
public void shutdown(Promise<Boolean> promise) {
131+
try {
132+
parent.shutdownGracefully().sync();
133+
child.shutdownGracefully().sync();
134+
promise.setSuccess(Boolean.TRUE);
135+
} catch (InterruptedException e) {
136+
logger.error("Error when shutting down server {}", e);
137+
promise.setFailure(e);
138+
}
139+
}
140+
141+
public void injectServerResponse(RequestResponseType responseType) {
142+
this.requestManager.injectServerResponse(responseType);
143+
}
144+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.cosmos.implementation.directconnectivity.TcpServerMock;
5+
6+
import io.netty.util.concurrent.DefaultPromise;
7+
import io.netty.util.concurrent.GlobalEventExecutor;
8+
import io.netty.util.concurrent.Promise;
9+
import org.slf4j.Logger;
10+
import org.slf4j.LoggerFactory;
11+
12+
import java.util.concurrent.ExecutionException;
13+
import java.util.concurrent.ExecutorService;
14+
import java.util.concurrent.Executors;
15+
16+
public class TcpServerFactory {
17+
18+
private final static ExecutorService executor = Executors.newFixedThreadPool(10);
19+
private final static Logger logger = LoggerFactory.getLogger(TcpServerFactory.class);
20+
21+
public static TcpServer startNewRntbdServer(int port) throws ExecutionException, InterruptedException {
22+
TcpServer server = new TcpServer(port);
23+
24+
Promise<Boolean> promise = new DefaultPromise<Boolean>(GlobalEventExecutor.INSTANCE);
25+
executor.execute(() -> {
26+
try {
27+
server.start(promise);
28+
} catch (InterruptedException e) {
29+
logger.error("Failed to start server {}", e);
30+
}
31+
});
32+
33+
// only return server when server has started successfully.
34+
promise.get();
35+
return server;
36+
}
37+
38+
public static void shutdownRntbdServer(TcpServer server) throws ExecutionException, InterruptedException {
39+
Promise<Boolean> promise = new DefaultPromise<Boolean>(GlobalEventExecutor.INSTANCE);
40+
executor.execute(() -> server.shutdown(promise));
41+
// only return when server has shutdown.
42+
promise.get();
43+
return;
44+
}
45+
}

0 commit comments

Comments
 (0)