diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index b5568efe400..8ec02f4f809 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -59,15 +59,17 @@ final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; - + private final LruCache callCredentialsCache; final String filterInstanceName; - GcpAuthenticationFilter(String name) { + GcpAuthenticationFilter(String name, int cacheSize) { filterInstanceName = checkNotNull(name, "name"); + this.callCredentialsCache = new LruCache<>(cacheSize); } - static final class Provider implements Filter.Provider { + private final int cacheSize = 10; + @Override public String[] typeUrls() { return new String[]{TYPE_URL}; @@ -80,7 +82,7 @@ public boolean isClientFilter() { @Override public GcpAuthenticationFilter newInstance(String name) { - return new GcpAuthenticationFilter(name); + return new GcpAuthenticationFilter(name, cacheSize); } @Override @@ -101,11 +103,14 @@ public ConfigOrError parseFilterConfig(Message rawProto // Validate cache_config if (gcpAuthnProto.hasCacheConfig()) { TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); - cacheSize = cacheConfig.getCacheSize().getValue(); - if (cacheSize == 0) { - return ConfigOrError.fromError( - "cache_config.cache_size must be greater than zero"); + if (cacheConfig.hasCacheSize()) { + cacheSize = cacheConfig.getCacheSize().getValue(); + if (cacheSize == 0) { + return ConfigOrError.fromError( + "cache_config.cache_size must be greater than zero"); + } } + // LruCache's size is an int and briefly exceeds its maximum size before evicting entries cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); } @@ -127,8 +132,9 @@ public ClientInterceptor buildClientInterceptor(FilterConfig config, @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); - LruCache callCredentialsCache = - new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize()); + synchronized (callCredentialsCache) { + callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize()); + } return new ClientInterceptor() { @Override public ClientCall interceptCall( @@ -254,23 +260,37 @@ public void sendMessage(ReqT message) {} private static final class LruCache { - private final Map cache; + private Map cache; + private int maxSize; LruCache(int maxSize) { - this.cache = new LinkedHashMap( - maxSize, - 0.75f, - true) { - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return size() > maxSize; - } - }; + this.maxSize = maxSize; + this.cache = createEvictingMap(maxSize); } V getOrInsert(K key, Function create) { return cache.computeIfAbsent(key, create); } + + private void resizeCache(int newSize) { + if (newSize >= maxSize) { + maxSize = newSize; + return; + } + Map newCache = createEvictingMap(newSize); + maxSize = newSize; + newCache.putAll(cache); + cache = newCache; + } + + private Map createEvictingMap(int size) { + return new LinkedHashMap(size, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > LruCache.this.maxSize; + } + }; + } } static class AudienceMetadataParser implements MetadataValueParser { diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index a5e142b4094..d84d8c9d768 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -28,11 +28,13 @@ import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; @@ -89,8 +91,8 @@ public class GcpAuthenticationFilterTest { @Test public void testNewFilterInstancesPerFilterName() { - assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")) - .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")); + assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)) + .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10)); } @Test @@ -152,7 +154,7 @@ public void testClientInterceptor_success() throws IOException, ResourceInvalidE .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); @@ -181,7 +183,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials() .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); @@ -190,7 +192,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials() interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); - verify(mockChannel, Mockito.times(2)) + verify(mockChannel, times(2)) .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); @@ -202,7 +204,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials() @Test public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -233,7 +235,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce Channel mockChannel = mock(Channel.class); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); @@ -244,7 +246,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce @Test public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -274,7 +276,7 @@ public void testClientInterceptor_incorrectClusterName() throws Exception { .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster") .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -300,7 +302,7 @@ public void testClientInterceptor_statusOrError() throws Exception { .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = mock(Channel.class); @@ -329,7 +331,7 @@ public void testClientInterceptor_notAudienceWrapper() .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); @@ -342,6 +344,115 @@ public void testClientInterceptor_notAudienceWrapper() assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type"); } + @Test + public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + ClientInterceptor interceptor1 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions1.getCredentials()); + ClientInterceptor interceptor2 + = filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + verify(mockChannel, times(2)) + .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1); + assertNotNull(capturedOptions2.getCredentials()); + + assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials()); + } + + @Test + public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + + ClientInterceptor interceptor1 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null); + Channel mockChannel1 = Mockito.mock(Channel.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(CallOptions.class); + interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1); + verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options1 = captor.getValue(); + // This will recreate the cache with max size of 1 and copy the credential for audience1. + ClientInterceptor interceptor2 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel2 = Mockito.mock(Channel.class); + interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2); + verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options2 = captor.getValue(); + + assertSame(options1.getCredentials(), options2.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + + // This will evict the credential for audience1 and add new credential for audience2 + ClientInterceptor interceptor3 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel3 = Mockito.mock(Channel.class); + interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3); + verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options3 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options3.getCredentials()); + + clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate))); + defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + + // This will create new credential for audience1 because it has been evicted + ClientInterceptor interceptor4 = + filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null); + Channel mockChannel4 = Mockito.mock(Channel.class); + interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4); + verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture()); + CallOptions options4 = captor.getValue(); + + assertNotSame(options1.getCredentials(), options4.getCredentials()); + } + private static LdsUpdate getLdsUpdate() { Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( serverName, RouterFilter.ROUTER_CONFIG); @@ -384,6 +495,19 @@ private static CdsUpdate getCdsUpdate() { } } + private static CdsUpdate getCdsUpdate2() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException { ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");