diff --git a/sdk/identity/azure-identity/pom.xml b/sdk/identity/azure-identity/pom.xml index 82dd64b7a0f5..a3eb2cc64d4a 100644 --- a/sdk/identity/azure-identity/pom.xml +++ b/sdk/identity/azure-identity/pom.xml @@ -115,6 +115,13 @@ 1.17.7 test + + + org.bouncycastle + bcpkix-lts8on + 2.73.8 + test + diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index a0c04ff2d952..40c3ac8ce4b4 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -10,6 +10,9 @@ import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import com.azure.identity.implementation.IdentityClientOptions; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyConfiguration; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyHttpClient; +import com.azure.identity.implementation.customtokenproxy.ProxyConfig; import com.azure.identity.implementation.util.LoggingUtil; import com.azure.identity.implementation.util.ValidationUtil; import reactor.core.publisher.Mono; @@ -89,6 +92,13 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; + if (identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (CustomTokenProxyConfiguration.isConfigured(configuration)) { + ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); + identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); + } + } + if (!(CoreUtils.isNullOrEmpty(tenantIdInput) || CoreUtils.isNullOrEmpty(federatedTokenFilePathInput) || CoreUtils.isNullOrEmpty(clientIdInput) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 95f16fd9628d..31e62c6f4e14 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -47,6 +47,7 @@ public class WorkloadIdentityCredentialBuilder extends AadCredentialBuilderBase { private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityCredentialBuilder.class); private String tokenFilePath; + private boolean enableTokenProxy = false; /** * Creates an instance of a WorkloadIdentityCredentialBuilder. @@ -66,6 +67,19 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } + /** + * Enables the Kubernetes token proxy feature for AKS workload identity scenarios. + * When enabled, the credential will attempt to use a custom token proxy configured through + * environment variables (AZURE_KUBERNETES_TOKEN_PROXY, AZURE_KUBERNETES_CA_FILE, + * AZURE_KUBERNETES_CA_DATA, AZURE_KUBERNETES_SNI_NAME). + * + * @return An updated instance of this builder with Kubernetes token proxy enabled. + */ + public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy() { + this.enableTokenProxy = true; + return this; + } + /** * Creates new {@link WorkloadIdentityCredential} with the configured options set. * @@ -88,6 +102,8 @@ public WorkloadIdentityCredential build() { ValidationUtil.validate(this.getClass().getSimpleName(), LOGGER, "Client ID", clientIdInput, "Tenant ID", tenantIdInput, "Service Token File Path", federatedTokenFilePathInput); + identityClientOptions.setEnableKubernetesTokenProxy(this.enableTokenProxy); + return new WorkloadIdentityCredential(tenantIdInput, clientIdInput, federatedTokenFilePathInput, identityClientOptions.clone()); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index d934c23c1d67..3c2d21d08c56 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -71,6 +71,7 @@ public final class IdentityClientOptions implements Cloneable { private List perRetryPolicies; private boolean instanceDiscovery; private String dacEnvConfiguredCredential; + private boolean enableKubernetesTokenProxy; private Duration credentialProcessTimeout = Duration.ofSeconds(10); @@ -833,6 +834,15 @@ public String getDACEnvConfiguredCredential() { return dacEnvConfiguredCredential; } + public boolean isKubernetesTokenProxyEnabled() { + return enableKubernetesTokenProxy; + } + + public IdentityClientOptions setEnableKubernetesTokenProxy(boolean enableTokenProxy) { + this.enableKubernetesTokenProxy = enableTokenProxy; + return this; + } + public IdentityClientOptions clone() { IdentityClientOptions clone = new IdentityClientOptions().setAdditionallyAllowedTenants(this.additionallyAllowedTenants) @@ -863,7 +873,8 @@ public IdentityClientOptions clone() { .setPerRetryPolicies(this.perRetryPolicies) .setBrowserCustomizationOptions(this.browserCustomizationOptions) .setChained(this.isChained) - .subscription(this.subscription); + .subscription(this.subscription) + .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java new file mode 100644 index 000000000000..da7eddc1e38b --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity.implementation.customtokenproxy; + +import com.azure.core.util.logging.ClientLogger; + +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.net.URISyntaxException; + +import com.azure.core.util.Configuration; +import com.azure.core.util.CoreUtils; + +public final class CustomTokenProxyConfiguration { + + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); + + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; + public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; + + private CustomTokenProxyConfiguration() { + } + + public static boolean isConfigured(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + return !CoreUtils.isNullOrEmpty(tokenProxyUrl); + } + + public static ProxyConfig parseAndValidate(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE); + String caData = configuration.get(AZURE_KUBERNETES_CA_DATA); + String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); + + if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) + || !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); + } + return null; + } + + if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set.")); + } + + URL proxyUrl = validateProxyUrl(tokenProxyUrl); + + byte[] caCertBytes = null; + if (!CoreUtils.isNullOrEmpty(caData)) { + try { + caCertBytes = caData.getBytes(StandardCharsets.UTF_8); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e)); + } + } + + ProxyConfig config = new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + return config; + } + + private static URL validateProxyUrl(String endpoint) { + if (CoreUtils.isNullOrEmpty(endpoint)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty")); + } + + try { + URI tokenProxy = new URI(endpoint); + + if (!"https".equals(tokenProxy.getScheme())) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme())); + } + + if (tokenProxy.getRawUserInfo() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain user info: " + endpoint)); + } + + if (tokenProxy.getRawQuery() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a query: " + endpoint)); + } + + if (tokenProxy.getRawFragment() != null) { + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a fragment: " + endpoint)); + } + + if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), tokenProxy.getPort(), "/", + null, null); + } + + return tokenProxy.toURL(); + + } catch (URISyntaxException | IllegalArgumentException e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e)); + } catch (Exception e) { + throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java new file mode 100644 index 000000000000..4cad1ba381be --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.Context; +import com.azure.core.util.CoreUtils; +import com.azure.core.util.logging.ClientLogger; +import com.azure.identity.implementation.util.IdentitySslUtil; + +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpClient implements HttpClient { + + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpClient.class); + + private final ProxyConfig proxyConfig; + private volatile SSLContext cachedSSLContext; + private volatile byte[] cachedFileContent; + private final URL proxyUrl; + private final String sniName; + private final byte[] caData; + private final String caFile; + + public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { + this.proxyConfig = proxyConfig; + this.proxyUrl = proxyConfig.getTokenProxyUrl(); + this.sniName = proxyConfig.getSniName(); + this.caData = proxyConfig.getCaData(); + this.caFile = proxyConfig.getCaFile(); + } + + @Override + public Mono send(HttpRequest request) { + return Mono.fromCallable(() -> sendSync(request, Context.NONE)); + } + + @Override + public HttpResponse sendSync(HttpRequest request, Context context) { + try { + HttpURLConnection connection = createConnection(request); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create connection to token proxy", e)); + } + } + + private HttpURLConnection createConnection(HttpRequest request) throws IOException { + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if (!CoreUtils.isNullOrEmpty(sniName)) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, sniName); + } + connection.setSSLSocketFactory(sslSocketFactory); + connection.setHostnameVerifier(sniAwareVerifier(sniName, proxyUrl)); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to set up SSL context for token proxy", e)); + } + + String method = request.getHttpMethod().toString(); + connection.setRequestMethod(method); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); + + boolean hasBody = request.getBodyAsBinaryData() != null && request.getBodyAsBinaryData().toBytes() != null; + + if (hasBody && ("POST".equals(method) || "PUT".equals(method) || "PATCH".equals(method))) { + connection.setDoOutput(true); + } else { + connection.setDoOutput(false); + } + + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); + + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + try (OutputStream os = connection.getOutputStream()) { + os.write(bytes); + os.flush(); + } + } + } + return connection; + } + + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException { + try { + String originalPath = originalUrl.getPath(); + String originalQuery = originalUrl.getQuery(); + + String tokenProxyBase = proxyUrl.toString(); + if (!tokenProxyBase.endsWith("/")) { + tokenProxyBase += "/"; + } + + URI combined = URI.create(tokenProxyBase) + .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + + String combinedStr = combined.toString(); + if (originalQuery != null && !originalQuery.isEmpty()) { + combinedStr += "?" + originalQuery; + } + + return new URL(combinedStr); + + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to rewrite token request for proxy", e)); + } + } + + private SSLContext getSSLContext() { + try { + // If no CA override provided, use default + if (CoreUtils.isNullOrEmpty(caFile) && (caData == null || caData.length == 0)) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = SSLContext.getDefault(); + } + } + return cachedSSLContext; + } + + // If CA data provided, use it + if (CoreUtils.isNullOrEmpty(caFile)) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = createSslContextFromBytes(caData); + } + } + return cachedSSLContext; + } + + // If CA file provided, read it (and re-read if it changes) + Path path = Paths.get(caFile); + if (!Files.exists(path)) { + throw LOGGER.logExceptionAsError(new RuntimeException("CA file not found: " + caFile)); + } + + byte[] currentContent = Files.readAllBytes(path); + + synchronized (this) { + if (currentContent.length == 0) { + throw LOGGER.logExceptionAsError(new RuntimeException("CA file " + caFile + " is empty")); + } + + if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) { + cachedSSLContext = createSslContextFromBytes(currentContent); + cachedFileContent = currentContent; + } + } + return cachedSSLContext; + + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to initialize SSLContext for proxy", e)); + } + } + + // Create SSLContext from byte array containing PEM certificate data + private SSLContext createSslContextFromBytes(byte[] certificateData) { + try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + + List certificates = new ArrayList<>(); + while (true) { + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + certificates.add(cert); + } catch (CertificateException e) { + break; + } + } + + if (certificates.isEmpty()) { + throw LOGGER.logExceptionAsError(new RuntimeException("No valid certificates found")); + } + return createSslContext(certificates); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext from bytes", e)); + } + } + + // Create SSLContext from a single X509Certificate + private SSLContext createSslContext(List certificates) { + try { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + int index = 1; + for (X509Certificate caCert : certificates) { + keystore.setCertificateEntry("ca-cert-" + index, caCert); + index++; + } + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keystore); + + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, tmf.getTrustManagers(), null); + return context; + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext", e)); + } + } + + private static HostnameVerifier sniAwareVerifier(String sniName, URL customProxyUrl) { + return (urlHost, session) -> { + String peerHost = session.getPeerHost(); + String expectedProxyHost = customProxyUrl.getHost(); + return peerHost.equalsIgnoreCase(expectedProxyHost) + || (!CoreUtils.isNullOrEmpty(sniName) && peerHost.equalsIgnoreCase(sniName)); + }; + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java new file mode 100644 index 000000000000..8eb14cea2764 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.logging.ClientLogger; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public final class CustomTokenProxyHttpResponse extends HttpResponse { + + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpResponse.class); + + // private final HttpRequest request; + private final int statusCode; + private final HttpHeaders headers; + private final HttpURLConnection connection; + private byte[] cachedResponseBodyBytes; + + public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { + super(request); + this.connection = connection; + this.statusCode = extractStatusCode(connection); + this.headers = extractHeaders(connection); + } + + private HttpHeaders extractHeaders(HttpURLConnection connection) { + HttpHeaders headers = new HttpHeaders(); + for (Map.Entry> entry : connection.getHeaderFields().entrySet()) { + String headerName = entry.getKey(); + if (headerName != null) { + for (String headerValue : entry.getValue()) { + headers.add(headerName, headerValue); + } + } + } + return headers; + } + + private int extractStatusCode(HttpURLConnection connection) { + try { + return connection.getResponseCode(); + } catch (IOException e) { + throw LOGGER + .logExceptionAsError(new RuntimeException("Failed to get status code from token proxy response", e)); + } + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getHeaderValue(String name) { + return headers.getValue(HttpHeaderName.fromString(name)); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + @Override + public Mono getBodyAsByteArray() { + return Mono.fromCallable(() -> { + if (cachedResponseBodyBytes != null) { + return cachedResponseBodyBytes; + } + + InputStream stream = null; + try { + stream = getResponseStream(); + if (stream == null) { + cachedResponseBodyBytes = new byte[0]; + return cachedResponseBodyBytes; + } + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + int n; + byte[] temp = new byte[4096]; + while ((n = stream.read(temp)) != -1) { + buffer.write(temp, 0, n); + } + cachedResponseBodyBytes = buffer.toByteArray(); + return cachedResponseBodyBytes; + } finally { + if (stream != null) { + stream.close(); + } + } + }); + } + + @Override + public Flux getBody() { + return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsString(StandardCharsets.UTF_8); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + + @Override + public void close() { + connection.disconnect(); + } + + private InputStream getResponseStream() throws IOException { + try { + return connection.getInputStream(); + } catch (IOException e) { + return connection.getErrorStream(); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java new file mode 100644 index 000000000000..171087bac6e9 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity.implementation.customtokenproxy; + +import java.net.URL; + +public class ProxyConfig { + private final URL tokenProxyUrl; + private final String sniName; + private final String caFile; + private final byte[] caData; + + public ProxyConfig(URL tokenProxyUrl, String sniName, String caFile, byte[] caData) { + this.tokenProxyUrl = tokenProxyUrl; + this.sniName = sniName; + this.caFile = caFile; + this.caData = caData; + } + + public URL getTokenProxyUrl() { + return tokenProxyUrl; + } + + public String getSniName() { + return sniName; + } + + public String getCaFile() { + return caFile; + } + + public byte[] getCaData() { + return caData; + } +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index ba932d5e5a55..5f6b7161cd69 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -3,15 +3,23 @@ package com.azure.identity.implementation.util; +import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; + +import java.io.IOException; +import java.net.Socket; +import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -19,6 +27,7 @@ import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.Collections; public final class IdentitySslUtil { public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER; @@ -125,4 +134,76 @@ private static String extractCertificateThumbprint(Certificate certificate, Clie throw logger.logExceptionAsError(new RuntimeException(e)); } } + + public static final class SniSslSocketFactory extends SSLSocketFactory { + private final SSLSocketFactory sslSocketFactory; + private final String sniName; + + public SniSslSocketFactory(SSLSocketFactory sslSocketFactory, String sniName) { + this.sslSocketFactory = sslSocketFactory; + this.sniName = sniName; + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { + Socket sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, host, port, autoClose); + configureSni(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) + throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, + int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public String[] getDefaultCipherSuites() { + return sslSocketFactory.getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + return sslSocketFactory.getSupportedCipherSuites(); + } + + private void configureSni(Socket socket) { + if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { + SSLSocket sslSocket = (SSLSocket) socket; + SSLParameters sslParameters = sslSocket.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(new RawSniServerName(sniName))); + sslSocket.setSSLParameters(sslParameters); + } + } + } + + public static final class RawSniServerName extends SNIServerName { + public RawSniServerName(String sniHost) { + super(0, sniHost.getBytes(StandardCharsets.UTF_8)); + } + } } diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java new file mode 100644 index 000000000000..1c3cd5581efc --- /dev/null +++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java @@ -0,0 +1,555 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity; + +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenRequestContext; +import com.azure.core.test.utils.TestConfigurationSource; +import com.azure.core.util.Configuration; +import com.azure.identity.util.TestUtils; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslProvider; +import io.netty.util.CharsetUtil; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import javax.net.ssl.SSLEngine; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.PrivateKey; +import java.security.Security; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.time.OffsetDateTime; +import java.util.Base64; +import java.util.Date; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * E2E test for WorkloadIdentityCredential with AKS Custom Token Proxy using Netty mock server. + * + */ +public class WorkloadIdentityCredentialIdentityBindingTest { + + private static final String AKS_SNI_NAME = "test-aks-proxy.ests.aks"; + + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_TENANT_ID = "test-tenant-id"; + private static final String MOCK_ACCESS_TOKEN = "mock_access_token_from_aks_proxy"; + + @TempDir + Path tempDir; + + private NioEventLoopGroup bossGroup; + private NioEventLoopGroup workerGroup; + private Channel serverChannel; + private String serverBaseUrl; + private Path tokenFilePath; + private Path caCertFilePath; + private String caCertPemData; + private AtomicInteger tokenRequestCount; + private AtomicInteger metadataRequestCount; + + @BeforeEach + public void setUp() throws Exception { + tokenRequestCount = new AtomicInteger(0); + metadataRequestCount = new AtomicInteger(0); + + // Generate key pair and certificate + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + + // Generate self-signed certificate with both localhost and AKS SNI name + X509Certificate certificate = generateCertificateWithAksSni(keyPair, AKS_SNI_NAME); + + // Convert certificate to PEM format + caCertPemData = toPemFormat(certificate); + + // Write CA certificate to file + caCertFilePath = tempDir.resolve("ca-cert.pem"); + Files.write(caCertFilePath, caCertPemData.getBytes(StandardCharsets.UTF_8)); + + // Create mock federated token file + tokenFilePath = tempDir.resolve("token.jwt"); + String mockJwt = createMockFederatedToken(); + Files.write(tokenFilePath, mockJwt.getBytes(StandardCharsets.UTF_8)); + + // Start Netty HTTPS server + startNettyHttpsServer(keyPair.getPrivate(), certificate); + } + + @AfterEach + public void cleanup() { + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + } + if (bossGroup != null) { + bossGroup.shutdownGracefully(); + } + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + } + } + + @Test + public void testAksProxyWithCaFile() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithCaData() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_DATA", caCertPemData) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithCaFileButNoSni() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithInvalidTokenFile() { + Path nonExistentTokenFile = tempDir.resolve("non-existent-token.jwt"); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", nonExistentTokenFile.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(nonExistentTokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when token file doesn't exist"); + assertEquals(0, tokenRequestCount.get(), "No token request should reach the server with invalid token file"); + } + + @Test + public void testAksProxyWithInvalidCaCertificate() throws Exception { + String invalidCertData = "-----BEGIN CERTIFICATE-----\nINVALID_BASE64_DATA\n-----END CERTIFICATE-----"; + Path invalidCaCertFile = tempDir.resolve("invalid-ca.pem"); + Files.write(invalidCaCertFile, invalidCertData.getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", invalidCaCertFile.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when CA certificate is invalid"); + assertEquals(0, tokenRequestCount.get(), "No token request should succeed with invalid CA certificate"); + } + + @Test + public void testAksProxyWithHttpScheme() { + String httpProxyUrl = serverBaseUrl.replace("https://", "http://"); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", httpProxyUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when proxy URL uses HTTP instead of HTTPS"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with HTTP proxy URL"); + } + + @Test + public void testAksProxyWithMalformedUrl() { + String malformedUrl = "not-a-valid-url-at-all"; + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", malformedUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when proxy URL is malformed"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with malformed proxy URL"); + } + + @Test + public void testAksProxyUnreachable() { + String unreachableProxyUrl = "https://localhost:19999"; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", unreachableProxyUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when proxy server is unreachable"); + assertEquals(0, tokenRequestCount.get(), "No token request should succeed when proxy is unreachable"); + } + + @Test + public void testAksProxyWithEmptyTokenFile() throws Exception { + Path emptyTokenFile = tempDir.resolve("empty-token.jwt"); + Files.write(emptyTokenFile, new byte[0]); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", emptyTokenFile.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(emptyTokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when token file is empty"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with empty token file"); + } + + @Test + public void testAksProxyWithUrlEncodedCharactersInPath() throws Exception { + String encodedPath = "/api%2Fv1/token"; + String proxyUrlWithEncoding = serverBaseUrl + encodedPath; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", proxyUrlWithEncoding) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock token"); + assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired"); + assertEquals(1, tokenRequestCount.get(), "Token request should be made once"); + } + + private X509Certificate generateCertificateWithAksSni(KeyPair keyPair, String aksSniName) throws Exception { + if (Security.getProvider("BC") == null) { + Security.addProvider(new BouncyCastleProvider()); + } + + long now = System.currentTimeMillis(); + Date notBefore = new Date(now - 1000L * 60 * 60); + Date notAfter = new Date(now + 1000L * 60 * 60 * 24 * 365); + BigInteger serial = BigInteger.valueOf(now); + + X500Name subject = new X500Name("CN=AKS-Proxy-Test"); + SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded()); + + X509v3CertificateBuilder certBuilder + = new X509v3CertificateBuilder(subject, serial, notBefore, notAfter, subject, publicKeyInfo); + + GeneralName[] sans = new GeneralName[] { + new GeneralName(GeneralName.dNSName, "localhost"), + new GeneralName(GeneralName.dNSName, new org.bouncycastle.asn1.DERIA5String(aksSniName)) }; + GeneralNames subjectAltNames = new GeneralNames(sans); + certBuilder.addExtension(Extension.subjectAlternativeName, false, subjectAltNames); + + ContentSigner signer + = new JcaContentSignerBuilder("SHA256withRSA").setProvider("BC").build(keyPair.getPrivate()); + + X509CertificateHolder certHolder = certBuilder.build(signer); + + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + try (InputStream in = new ByteArrayInputStream(certHolder.getEncoded())) { + return (X509Certificate) certFactory.generateCertificate(in); + } + } + + private void startNettyHttpsServer(PrivateKey privateKey, X509Certificate certificate) throws Exception { + SslContext sslContext + = SslContextBuilder.forServer(privateKey, certificate).sslProvider(SslProvider.OPENSSL).build(); + + bossGroup = new NioEventLoopGroup(1); + workerGroup = new NioEventLoopGroup(); + + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + + SslHandler sslHandler = sslContext.newHandler(ch.alloc()); + + sslHandler.handshakeFuture().addListener(future -> { + if (future.isSuccess()) { + SSLEngine engine = sslHandler.engine(); + System.out.println("OpenSSL handshake successful."); + } + }); + + pipeline.addLast(sslHandler); + pipeline.addLast(new HttpServerCodec()); + pipeline.addLast(new HttpObjectAggregator(65536)); + pipeline.addLast(new HttpRequestHandler()); + } + }); + + serverChannel = bootstrap.bind("localhost", 0).sync().channel(); + InetSocketAddress socketAddress = (InetSocketAddress) serverChannel.localAddress(); + int port = socketAddress.getPort(); + serverBaseUrl = "https://localhost:" + port; + + System.out.println("Netty HTTPS server (OpenSSL) started at: " + serverBaseUrl); + } + + private class HttpRequestHandler extends SimpleChannelInboundHandler { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { + String uri = request.uri(); + String responseBody; + + if (uri.contains("/.well-known/openid-configuration")) { + metadataRequestCount.incrementAndGet(); + String base = serverBaseUrl + "/" + TEST_TENANT_ID; + responseBody = String + .format("{%n" + " \"token_endpoint\": \"%s/oauth2/v2.0/token\",%n" + " \"issuer\": \"%s/v2.0\",%n" + + " \"authorization_endpoint\": \"%s/oauth2/v2.0/authorize\"%n" + "}", base, base, base); + } else { + tokenRequestCount.incrementAndGet(); + responseBody = String.format( + "{\"token_type\":\"Bearer\",\"expires_in\":3600,\"ext_expires_in\":3600,\"access_token\":\"%s\"}", + MOCK_ACCESS_TOKEN); + } + + ByteBuf content = Unpooled.copiedBuffer(responseBody, CharsetUtil.UTF_8); + FullHttpResponse response + = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json"); + response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()); + + ctx.writeAndFlush(response); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } + } + + private String toPemFormat(X509Certificate certificate) throws Exception { + Base64.Encoder encoder = Base64.getMimeEncoder(64, "\n".getBytes(StandardCharsets.UTF_8)); + byte[] encoded = encoder.encode(certificate.getEncoded()); + return "-----BEGIN CERTIFICATE-----\n" + new String(encoded, StandardCharsets.UTF_8) + + "\n-----END CERTIFICATE-----\n"; + } + + private String createMockFederatedToken() { + String header = Base64.getUrlEncoder() + .withoutPadding() + .encodeToString("{\"alg\":\"RS256\",\"typ\":\"JWT\"}".getBytes(StandardCharsets.UTF_8)); + + long exp = System.currentTimeMillis() / 1000 + 3600; + String payload = Base64.getUrlEncoder() + .withoutPadding() + .encodeToString(String.format( + "{\"aud\":\"api://AzureADTokenExchange\",\"exp\":%d,\"iss\":\"kubernetes.io/serviceaccount\",\"sub\":\"system:serviceaccount:default:workload-identity-sa\"}", + exp).getBytes(StandardCharsets.UTF_8)); + + String signature + = Base64.getUrlEncoder().withoutPadding().encodeToString("mock-signature".getBytes(StandardCharsets.UTF_8)); + + return header + "." + payload + "." + signature; + } +} diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java index d2404d9b302b..e52c3ef8735f 100644 --- a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java +++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java @@ -35,6 +35,10 @@ public class WorkloadIdentityCredentialTest { private static final String CLIENT_ID = UUID.randomUUID().toString(); + private static final String ENV_PROXY_URL = "AZURE_KUBERNETES_TOKEN_PROXY"; + private static final String ENV_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + private static final String ENV_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + private static final String ENV_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; @Test public void testWorkloadIdentityFlow(@TempDir Path tempDir) throws IOException { @@ -192,4 +196,342 @@ public void testFileReadingError(@TempDir Path tempDir) { assertTrue(error.getCause() instanceof IOException); // Original IOException from Files.readAllBytes }).verify(); } + + @Test + public void testProxyEnabledWithProxyUrlGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + + @Test + public void testProxyEnabledWithoutProxyUrlGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)); + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + + @Test + public void testProxyEnabledInvalidProxyUrlSchemeFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://localhost"; + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "http://not-https.example.com")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithQueryFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com?x=y")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithFragmentFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com#frag")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithUserInfoFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://user:pass@proxy.example.com")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testCaFileAndCaDataPresentFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Path caFile = tempDir.resolve("ca.crt"); + Files.write(caFile, + "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8)); + + String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----"; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com") + .put(ENV_CA_FILE, caFile.toString()) + .put(ENV_CA_DATA, caData)); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyEnabledWithProxyUrlGetsTokenSync(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identitySyncClientMock + = mockConstruction(IdentitySyncClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())) + .thenThrow(new IllegalStateException("Test")); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessTokenSync(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + AccessToken token = credential.getTokenSync(request1); + + assertTrue(token1.equals(token.getToken())); + assertTrue(expiresAt.getSecond() == token.getExpiresAt().getSecond()); + assertNotNull(identitySyncClientMock); + } + } + + @Test + public void testProxyUrlWithCaDataAcquiresToken(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + String token1 = "token-ca-data"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + + String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----"; + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://token-proxy.example.com") + .put(ENV_CA_DATA, caData)); + + try (MockedConstruction mocked + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(cred.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && token.getExpiresAt().getSecond() == expiresAt.getSecond()) + .verifyComplete(); + + assertNotNull(mocked); + } + } + + @Test + public void testProxyUrlWithCaFileGetsToken(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + String token1 = "tok-ca-file"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + Path caFile = tempDir.resolve("ca.crt"); + Files.write(caFile, + "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://token-proxy.example.com") + .put(ENV_CA_FILE, caFile.toString())); + + try (MockedConstruction mocked + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(cred.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && token.getExpiresAt().getSecond() == expiresAt.getSecond()) + .verifyComplete(); + + assertNotNull(mocked); + } + } + + @Test + public void testProxyEnabledWithSniNameGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + String sniName = "615f3b8ad7eb011a09ed3b762e404de43ebc7ade0802a34c9fd322b688c3a655.ests.aks"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl) + .put(ENV_SNI_NAME, sniName)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + }