diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d7c2267c48f..131ae6b6125 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -55,20 +55,19 @@ protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext certificateValidationContextdationContext) throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - // Null rootCertInstance implies hasSystemRootCerts because of the check in - // CertProviderClientSslContextProviderFactory. - if (rootCertInstance != null) { - if (savedSpiffeTrustMap != null) { - sslContextBuilder = sslContextBuilder.trustManager( + if (savedSpiffeTrustMap != null) { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContextdationContext)); + } else if (savedTrustedRoots != null) { + sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( - savedSpiffeTrustMap, + savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContextdationContext)); - } else { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); - } + } else { + // Should be impossible because of the check in CertProviderClientSslContextProviderFactory + throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..2570dcb731b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -24,6 +24,7 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.DynamicSslContextProvider; +import java.io.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; @@ -34,8 +35,8 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements CertificateProvider.Watcher { - @Nullable private final CertificateProviderStore.Handle certHandle; - @Nullable private final CertificateProviderStore.Handle rootCertHandle; + @Nullable private final NoExceptionCloseable certHandle; + @Nullable private final NoExceptionCloseable rootCertHandle; @Nullable private final CertificateProviderInstance certInstance; @Nullable protected final CertificateProviderInstance rootCertInstance; @Nullable protected PrivateKey savedKey; @@ -55,24 +56,33 @@ protected CertProviderSslContextProvider( super(tlsContext, staticCertValidationContext); this.certInstance = certInstance; this.rootCertInstance = rootCertInstance; - String certInstanceName = null; - if (certInstance != null && certInstance.isInitialized()) { - certInstanceName = certInstance.getInstanceName(); + this.isUsingSystemRootCerts = rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); + boolean createCertInstance = certInstance != null && certInstance.isInitialized(); + boolean createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); + boolean sharedCertInstance = createCertInstance && createRootCertInstance + && rootCertInstance.getInstanceName().equals(certInstance.getInstanceName()); + if (createCertInstance) { CertificateProviderInfo certProviderInstanceConfig = - getCertProviderConfig(certProviders, certInstanceName); + getCertProviderConfig(certProviders, certInstance.getInstanceName()); + CertificateProvider.Watcher watcher = this; + if (!sharedCertInstance && !isUsingSystemRootCerts) { + watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); + } + // TODO: Previously we'd hang if certProviderInstanceConfig were null or + // certInstance.isInitialized() == false. Now we'll proceed. Those should be errors, or are + // they impossible and should be assertions? certHandle = certProviderInstanceConfig == null ? null : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + watcher, + true)::close; } else { certHandle = null; } - if (rootCertInstance != null - && rootCertInstance.isInitialized() - && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? null @@ -80,13 +90,16 @@ protected CertProviderSslContextProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + new IgnoreUpdatesWatcher(this, /* ignoreRootCertUpdates= */ false), + false)::close; + } else if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + SystemRootCertificateProvider systemRootProvider = new SystemRootCertificateProvider(this); + systemRootProvider.start(); + rootCertHandle = systemRootProvider::close; } else { rootCertHandle = null; } - this.isUsingSystemRootCerts = rootCertInstance == null - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); } private static CertificateProviderInfo getCertProviderConfig( @@ -150,17 +163,16 @@ public final void updateSpiffeTrustMap(Map> spiffe private void updateSslContextWhenReady() { if (isMtls()) { - if (savedKey != null - && (savedTrustedRoots != null || isUsingSystemRootCerts || savedSpiffeTrustMap != null)) { + if (savedKey != null && (savedTrustedRoots != null || savedSpiffeTrustMap != null)) { updateSslContext(); clearKeysAndCerts(); } - } else if (isClientSideTls()) { + } else if (isRegularTlsAndClientSide()) { if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isServerSideTls()) { + } else if (isRegularTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -170,8 +182,10 @@ private void updateSslContextWhenReady() { private void clearKeysAndCerts() { savedKey = null; - savedTrustedRoots = null; - savedSpiffeTrustMap = null; + if (!isUsingSystemRootCerts) { + savedTrustedRoots = null; + savedSpiffeTrustMap = null; + } savedCertChain = null; } @@ -179,11 +193,11 @@ protected final boolean isMtls() { return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isClientSideTls() { - return rootCertInstance != null && certInstance == null; + protected final boolean isRegularTlsAndClientSide() { + return (rootCertInstance != null || isUsingSystemRootCerts) && certInstance == null; } - protected final boolean isServerSideTls() { + protected final boolean isRegularTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } @@ -201,4 +215,9 @@ public final void close() { rootCertHandle.close(); } } + + interface NoExceptionCloseable extends Closeable { + @Override + void close(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java new file mode 100644 index 00000000000..cd9d88be41b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +public final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + + @VisibleForTesting + public CertificateProvider.Watcher getDelegate() { + return delegate; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java new file mode 100644 index 00000000000..7c60f714e71 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import io.grpc.Status; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * An non-registered provider for CertProviderSslContextProvider to use the same code path for + * system root certs as provider-obtained certs. + */ +final class SystemRootCertificateProvider extends CertificateProvider { + public SystemRootCertificateProvider(CertificateProvider.Watcher watcher) { + super(new DistributorWatcher(), false); + getWatcher().addWatcher(watcher); + } + + @Override + public void start() { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + getWatcher().updateTrustedRoots(rootCerts); + } catch (KeyStoreException | NoSuchAlgorithmException ex) { + getWatcher().onError(Status.UNAVAILABLE + .withDescription("Could not load system root certs") + .withCause(ex)); + } + } + + @Override + public void close() { + // Unnecessary because there's no more callbacks, but do it for good measure + for (Watcher watcher : getWatcher().getDownstreamWatchers()) { + getWatcher().removeWatcher(watcher); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..e46e440475a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -37,6 +37,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -117,6 +118,8 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { + private static final String SAN_TO_MATCH = "waterzooi.test.google.be"; + @Parameter public Boolean enableSpiffe; private Boolean originalEnableSpiffe; @@ -206,7 +209,8 @@ public void tlsClientServer_noClientAuthentication() throws Exception { * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationContext() + throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -217,7 +221,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -233,7 +237,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th * Uses common_tls_context.validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_validationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -244,7 +248,7 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_PEM_FILE, false, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -255,9 +259,60 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except } } + @Test + public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SAN_TO_MATCH, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, "server1.test.google.in", false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("Expected handshake failure exception"); + } catch (StatusRuntimeException e) { + assertThat(e.getCause()).isInstanceOf(SSLHandshakeException.class); + assertThat(e.getCause().getCause()).isInstanceOf(CertificateException.class); + assertThat(e.getCause().getCause().getMessage()).isEqualTo( + "Peer certificate SAN check failed"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + /** * Use system root ca cert for TLS channel - mTLS. - * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Exception { @@ -266,12 +321,12 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, - null, false, false); + null, false, true); buildServerWithTlsContext(downstreamTlsContext); UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -549,20 +604,25 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); } + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext) { + boolean useCombinedValidationContext, String sanToMatch, boolean isMtls) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); if (useCombinedValidationContext) { return CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "google_cloud_private_spiffe-client", "ROOT", null, + isMtls ? "google_cloud_private_spiffe-client" : null, + isMtls ? "ROOT" : null, null, null, null, CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .addMatchSubjectAltNames( + StringMatcher.newBuilder() + .setExact(sanToMatch)) .build()); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 397fe01e0f5..a0eac581d5c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -37,6 +37,7 @@ import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.IgnoreUpdatesWatcher; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; import org.junit.Before; import org.junit.Test; @@ -84,7 +85,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); @@ -119,7 +120,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -145,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -209,8 +210,8 @@ public void createCertProviderClientSslContextProvider_2providers() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -246,8 +247,8 @@ public void createNewCertProviderClientSslContextProvider_withSans() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -280,7 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } static void createAndRegisterProviderProvider( @@ -310,11 +311,18 @@ public CertificateProvider answer(InvocationOnMock invocation) throws Throwable } static void verifyWatcher( - SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor, + boolean usesDelegateWatcher) { assertThat(watcherCaptor).isNotNull(); assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); - assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) - .isSameInstanceAs(sslContextProvider); + if (usesDelegateWatcher) { + assertThat(((IgnoreUpdatesWatcher) watcherCaptor.getDownstreamWatchers().iterator().next()) + .getDelegate()) + .isSameInstanceAs(sslContextProvider); + } else { + assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } static CommonTlsContext.Builder addFilenames( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index cf86b511f1f..7a5a6c00639 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -78,7 +78,7 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); @@ -117,7 +117,7 @@ public void bothPresent_expectCertProviderServerSslContextProvider() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -144,7 +144,7 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +179,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); } @Test @@ -210,8 +210,8 @@ public void createCertProviderServerSslContextProvider_2providers() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -249,7 +249,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index b0800458d66..3c734df3f5a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -72,15 +72,28 @@ private CertProviderClientSslContextProvider getSslContextProvider( String rootInstanceName, Bootstrapper.BootstrapInfo bootstrapInfo, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { - EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - certInstanceName, - "cert-default", - rootInstanceName, - "root-default", - alpnProtocols, - staticCertValidationContext); + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + if (useSystemRootCerts) { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } else { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, @@ -122,7 +135,7 @@ public void testProviderForClient_mtls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); @@ -173,6 +186,87 @@ public void testProviderForClient_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_systemRootCerts_regularTls() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + null, + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContext()).isNotNull(); + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContext()); + + assertThat(watcherCaptor[0]).isNull(); + } + + @Test + public void testProviderForClient_systemRootCerts_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContext()).isNotNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e. different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForClient_mtls_newXds() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -248,7 +342,7 @@ public void testProviderForClient_queueExecutor() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); QueuedExecutor queuedExecutor = new QueuedExecutor(); TestCallback testCallback = @@ -281,7 +375,7 @@ public void testProviderForClient_tls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); @@ -318,7 +412,7 @@ public void testProviderForClient_sslContextException_onError() throws Exception "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */null, - staticCertValidationContext); + staticCertValidationContext, false); TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); provider.addCallback(testCallback); @@ -350,7 +444,7 @@ public void testProviderForClient_rootInstanceNull_and_notUsingSystemRootCerts_e /* rootInstanceName= */ null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); fail("exception expected"); } catch (UnsupportedOperationException expected) { assertThat(expected).hasMessageThat().contains("Unsupported configurations in " @@ -373,7 +467,7 @@ public void testProviderForClient_rootInstanceNull_but_isUsingSystemRootCerts_va CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .build(), false); } static class QueuedExecutor implements Executor {