Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable MTLS_S2A bound token by default for gRPC S2A enabled flows #3591

Merged
merged 14 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import com.google.common.io.Files;
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.CompositeChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
Expand Down Expand Up @@ -592,6 +593,41 @@ ChannelCredentials createS2ASecuredChannelCredentials() {
}
}

boolean isMtlsS2AHardBoundTokensEnabled() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a thought (nothing that needs to be changed in this PR): With how many helper methods we have for S2A and hard bound tokens, I wonder if we can split these methods into a helper class in Gax-Grpc (something like S2AMtlsContext or something)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this would be help to reduce the complexity in the InstantiatingGrpcChannelProvider file. I'm happy to do the cleanup of that in a followup CL.

if (!useS2A) {
// If S2A cannot be used, {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens should not be
// used
return false;
}
if (allowedHardBoundTokenTypes == null
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
|| credentials == null
|| !(credentials instanceof ComputeEngineCredentials)) {
return false;
}
for (HardBoundTokenTypes boundTokenTypes : allowedHardBoundTokenTypes) {
if (boundTokenTypes == HardBoundTokenTypes.MTLS_S2A) {
return true;
}
}
return false;
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
}

CallCredentials createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
ComputeEngineCredentials.Builder credsBuilder =
((ComputeEngineCredentials) credentials).toBuilder();
// We only set scopes and HTTP transport factory from the original credentials because
// only those are used in gRPC CallCredentials to fetch request metadata.
return MoreCallCredentials.from(
ComputeEngineCredentials.newBuilder()
.setScopes(credsBuilder.getScopes())
.setHttpTransportFactory(credsBuilder.getHttpTransportFactory())
.setGoogleAuthTransport(googleAuthTransport)
.setBindingEnforcement(bindingEnforcement)
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
.build());
}

private ManagedChannel createSingleChannel() throws IOException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headersWithDuplicatesRemoved);
Expand Down Expand Up @@ -648,6 +684,15 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
if (isMtlsS2AHardBoundTokensEnabled()) {
// Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
CallCredentials callCreds =
createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
ComputeEngineCredentials.BindingEnforcement.ON);
channelCredentials = CompositeChannelCredentials.create(channelCredentials, callCreds);
}
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,82 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
}

@Test
void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(false)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesNull() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(null)
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_nullCreds() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(null)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(CloudShellCredentials.create(3000))
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() {
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isFalse();
}

@Test
void isMtlsS2AHardBoundTokensEnabled_returnsTrue() {
lqiu96 marked this conversation as resolved.
Show resolved Hide resolved
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setUseS2A(true)
.setAllowHardBoundTokenTypes(
Collections.singletonList(
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
.setCredentials(computeEngineCredentials)
.build();
Truth.assertThat(provider.isMtlsS2AHardBoundTokensEnabled()).isTrue();
}

private static class FakeLogHandler extends Handler {

List<LogRecord> records = new ArrayList<>();
Expand Down
Loading