Skip to content

xds: float LRU cache across interceptors #11992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,17 @@

static final String TYPE_URL =
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";

private final LruCache<String, CallCredentials> callCredentialsCache;
final String filterInstanceName;

GcpAuthenticationFilter(String name) {
GcpAuthenticationFilter(String name, int cacheSize) {
filterInstanceName = checkNotNull(name, "name");
this.callCredentialsCache = new LruCache<>(cacheSize);
}


static final class Provider implements Filter.Provider {
private final int cacheSize = 10;

@Override
public String[] typeUrls() {
return new String[]{TYPE_URL};
Expand All @@ -80,7 +82,7 @@

@Override
public GcpAuthenticationFilter newInstance(String name) {
return new GcpAuthenticationFilter(name);
return new GcpAuthenticationFilter(name, cacheSize);

Check warning on line 85 in xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java#L85

Added line #L85 was not covered by tests
}

@Override
Expand All @@ -101,11 +103,14 @@
// Validate cache_config
if (gcpAuthnProto.hasCacheConfig()) {
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
cacheSize = cacheConfig.getCacheSize().getValue();
if (cacheSize == 0) {
return ConfigOrError.fromError(
"cache_config.cache_size must be greater than zero");
if (cacheConfig.hasCacheSize()) {
cacheSize = cacheConfig.getCacheSize().getValue();
if (cacheSize == 0) {
return ConfigOrError.fromError(
"cache_config.cache_size must be greater than zero");
}
}

// LruCache's size is an int and briefly exceeds its maximum size before evicting entries
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
}
Expand All @@ -127,8 +132,9 @@
@Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {

ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
LruCache<String, CallCredentials> callCredentialsCache =
new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
synchronized (callCredentialsCache) {
callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize());
}
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
Expand Down Expand Up @@ -254,23 +260,37 @@

private static final class LruCache<K, V> {

private final Map<K, V> cache;
private Map<K, V> cache;
private int maxSize;

LruCache(int maxSize) {
this.cache = new LinkedHashMap<K, V>(
maxSize,
0.75f,
true) {
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > maxSize;
}
};
this.maxSize = maxSize;
this.cache = createEvictingMap(maxSize);
}

V getOrInsert(K key, Function<K, V> create) {
return cache.computeIfAbsent(key, create);
}

private void resizeCache(int newSize) {
if (newSize >= maxSize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't understand this. I thought if you get a bigger new size you should resize, and if you get an equal or smaller new size only that should be a no-op because decreasing cache size doesn't make since when another filter instance created a bigger cache.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your point but well, cache_size is configuration coming externally. We should adhere to the configuration. If the cache size decreases, then we can't use the normal code to discard entries, as that only discards a single entry. So without special code, the configuration wouldn't be observed.

It is generally really bad to not observe configuration updates, from a testing/debugging perspective. You do a config push and see behavior X, but then after restarting you see behavior Y. It can be really hard to track down the bugs.

The cache size doesn't change often, so the performance hit would be temporary (although there are still issues with that). It's also much easier to figure out what happened. If you keep the old cache, then someone may be debugging something and it matters what happened a month ago.

I discussed this with Eric offline on the similar lines and understood this POV.

maxSize = newSize;
return;
}
Map<K, V> newCache = createEvictingMap(newSize);
maxSize = newSize;
newCache.putAll(cache);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because maxSize is only updated after the map is created, this map will contain all the entries from cache.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

createEvictingMap is creating an empty LinkedHashMap. There are separator constructors

LinkedHashMap(int initialCapacity, float loadFactor, boolean accessOrder)
LinkedHashMap(Map<? extends K,? extends V> m)

but none that both take an old map and specify the initialCapacity.

cache = newCache;
}

private Map<K, V> createEvictingMap(int size) {
return new LinkedHashMap<K, V>(size, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > LruCache.this.maxSize;
}
};
}
}

static class AudienceMetadataParser implements MetadataValueParser {
Expand Down
141 changes: 130 additions & 11 deletions xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -89,8 +91,8 @@ public class GcpAuthenticationFilterTest {

@Test
public void testNewFilterInstancesPerFilterName() {
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"))
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"));
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10))
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10));
}

@Test
Expand Down Expand Up @@ -152,7 +154,7 @@ public void testClientInterceptor_success() throws IOException, ResourceInvalidE
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand Down Expand Up @@ -181,7 +183,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand All @@ -190,7 +192,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);

verify(mockChannel, Mockito.times(2))
verify(mockChannel, times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);
Expand All @@ -202,7 +204,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
@Test
public void testClientInterceptor_withoutClusterSelectionKey() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -233,7 +235,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
Channel mockChannel = mock(Channel.class);

GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
Expand All @@ -244,7 +246,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
@Test
public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception {
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -274,7 +276,7 @@ public void testClientInterceptor_incorrectClusterName() throws Exception {
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand All @@ -300,7 +302,7 @@ public void testClientInterceptor_statusOrError() throws Exception {
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = mock(Channel.class);
Expand Down Expand Up @@ -329,7 +331,7 @@ public void testClientInterceptor_notAudienceWrapper()
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
Expand All @@ -342,6 +344,110 @@ public void testClientInterceptor_notAudienceWrapper()
assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type");
}

@Test
public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException {
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
CallOptions callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
ClientInterceptor interceptor1
= filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
Channel mockChannel = Mockito.mock(Channel.class);
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);

interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0);
assertNotNull(capturedOptions1.getCredentials());
ClientInterceptor interceptor2
= filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
verify(mockChannel, times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1);
assertNotNull(capturedOptions2.getCredentials());

assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cached value is tested but not eviction behavior. With cache max size of 1, you need to test using a cluster resource with a 2nd audience string and get the credentials and assert the behavior for the 1st audience string's call credentials getting evicted.

filter max size 2
audience-1 : call-credentials-for-audience-1
filter max size 1
audience-1: call-credentials-1-for-audience-1 (same instance)
audience-2: call-credentials-1-for-audience-2 (causes eviction of call-credentials-1-for-audience-1)
audience-1: call-credentials-2-for-audience-1 (new instance)

}

@Test
public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException {
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
CallOptions callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
ClientInterceptor interceptor1 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
Channel mockChannel1 = Mockito.mock(Channel.class);
ArgumentCaptor<CallOptions> captor = ArgumentCaptor.forClass(CallOptions.class);
interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1);
verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture());
CallOptions options1 = captor.getValue();
assertNotNull(options1.getCredentials());
ClientInterceptor interceptor2 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel2 = Mockito.mock(Channel.class);
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2);
verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture());
CallOptions options2 = captor.getValue();
assertNotNull(options2.getCredentials());
clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate)));
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
ClientInterceptor interceptor3 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel3 = Mockito.mock(Channel.class);
interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3);
verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture());
CallOptions options3 = captor.getValue();
assertNotNull(options3.getCredentials());
clusterConfig = new XdsConfig.XdsClusterConfig(
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
.setListener(ldsUpdate)
.setRoute(rdsUpdate)
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
callOptionsWithXds = CallOptions.DEFAULT
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);

ClientInterceptor interceptor4 =
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
Channel mockChannel4 = Mockito.mock(Channel.class);
interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4);
verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture());
CallOptions options4 = captor.getValue();
assertNotNull(options4.getCredentials());

assertSame(options1.getCredentials(), options2.getCredentials());
assertNotSame(options1.getCredentials(), options3.getCredentials());
assertNotSame(options1.getCredentials(), options4.getCredentials());
}

private static LdsUpdate getLdsUpdate() {
Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
serverName, RouterFilter.ROUTER_CONFIG);
Expand Down Expand Up @@ -384,6 +490,19 @@ private static CdsUpdate getCdsUpdate() {
}
}

private static CdsUpdate getCdsUpdate2() {
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE"));
try {
CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
.lbPolicyConfig(getWrrLbConfigAsMap());
return cdsUpdate.parsedMetadata(parsedMetadata.build()).build();
} catch (IOException ex) {
return null;
}
}

private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException {
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");
Expand Down