Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable MTLS_S2A bound token by default for gRPC S2A enabled flows #3591

Merged
merged 14 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import com.google.common.io.Files;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
Expand All @@ -69,6 +70,7 @@
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -139,14 +141,15 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Boolean keepAliveWithoutCalls;
private final ChannelPoolSettings channelPoolSettings;
@Nullable private final Credentials credentials;
@Nullable private final CallCredentials mtlsS2ACallCredentials;
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@Nullable private final Boolean attemptDirectPathXds;
@Nullable private final Boolean allowNonDefaultServiceAccount;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;
@Nullable private final SecureSessionAgent s2aConfigProvider;
@Nullable private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();

@Nullable
Expand Down Expand Up @@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.channelPoolSettings = builder.channelPoolSettings;
this.channelConfigurator = builder.channelConfigurator;
this.credentials = builder.credentials;
this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials;
this.channelPrimer = builder.channelPrimer;
this.attemptDirectPath = builder.attemptDirectPath;
this.attemptDirectPathXds = builder.attemptDirectPathXds;
Expand Down Expand Up @@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
if (mtlsS2ACallCredentials != null) {
// Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
channelCredentials =
CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials);
}
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
Expand Down Expand Up @@ -812,18 +822,20 @@ public static final class Builder {
@Nullable private Boolean keepAliveWithoutCalls;
@Nullable private ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
@Nullable private Credentials credentials;
@Nullable private CallCredentials mtlsS2ACallCredentials;
@Nullable private ChannelPrimer channelPrimer;
private ChannelPoolSettings channelPoolSettings;
@Nullable private Boolean attemptDirectPath;
@Nullable private Boolean attemptDirectPathXds;
@Nullable private Boolean allowNonDefaultServiceAccount;
@Nullable private ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;

private Builder() {
processorCount = Runtime.getRuntime().availableProcessors();
envProvider = System::getenv;
channelPoolSettings = ChannelPoolSettings.staticallySized(1);
allowedHardBoundTokenTypes = new ArrayList<>();
}

private Builder(InstantiatingGrpcChannelProvider provider) {
Expand All @@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls;
this.channelConfigurator = provider.channelConfigurator;
this.credentials = provider.credentials;
this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials;
this.channelPrimer = provider.channelPrimer;
this.channelPoolSettings = provider.channelPoolSettings;
this.attemptDirectPath = provider.attemptDirectPath;
this.attemptDirectPathXds = provider.attemptDirectPathXds;
this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount;
this.allowedHardBoundTokenTypes = provider.allowedHardBoundTokenTypes;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
this.s2aConfigProvider = provider.s2aConfigProvider;
Expand Down Expand Up @@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) {
*/
@InternalApi
public Builder setAllowHardBoundTokenTypes(List<HardBoundTokenTypes> allowedValues) {
this.allowedHardBoundTokenTypes = allowedValues;
this.allowedHardBoundTokenTypes =
Preconditions.checkNotNull(
allowedValues, "List of allowed HardBoundTokenTypes cannot be null");
;
return this;
}

Expand Down Expand Up @@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
return this;
}

boolean isMtlsS2AHardBoundTokensEnabled() {
// If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't
// contain
// {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type
// {@code
// ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens
// should
// not
// be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS
// channels established using S2A and when tokens from MDS (i.e {@code
// ComputeEngineCredentials}
// are being used.
if (!this.useS2A
|| this.allowedHardBoundTokenTypes.isEmpty()
|| this.credentials == null
|| !(this.credentials instanceof ComputeEngineCredentials)) {
return false;
}
return allowedHardBoundTokenTypes.stream()
.anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A));
}

CallCredentials createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
// We only set scopes and HTTP transport factory from the original credentials because
// only those are used in gRPC CallCredentials to fetch request metadata.
return MoreCallCredentials.from(
((ComputeEngineCredentials) this.credentials)
.toBuilder()
.setGoogleAuthTransport(googleAuthTransport)
.setBindingEnforcement(bindingEnforcement)
.build());
}

public InstantiatingGrpcChannelProvider build() {
if (isMtlsS2AHardBoundTokensEnabled()) {
// Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
this.mtlsS2ACallCredentials =
createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
ComputeEngineCredentials.BindingEnforcement.ON);
}
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
new InstantiatingGrpcChannelProvider(this);
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,79 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
}

@Test
void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(false)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesEmpty() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(new ArrayList<>())
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_nullCreds() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(null);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(CloudShellCredentials.create(3000));
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() {
InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS))
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ATokenAllowedInList() {
List<InstantiatingGrpcChannelProvider.HardBoundTokenTypes> allowHardBoundTokenTypes =
new ArrayList<>();
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A);
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS);

InstantiatingGrpcChannelProvider.Builder providerBuilder =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(allowHardBoundTokenTypes)
.setCredentials(computeEngineCredentials);
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isTrue();
}

private static class FakeLogHandler extends Handler {

List<LogRecord> records = new ArrayList<>();
Expand Down
Loading