From b7ca0dcda6560a5d799022fa403672f7aa963aef Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sat, 18 Oct 2025 00:39:13 -0700 Subject: [PATCH 1/7] changes to add proxy in wic --- .../CustomTokenProxyConfiguration.java | 119 ++++++++++ .../CustomTokenProxyHttpClient.java | 218 ++++++++++++++++++ .../CustomTokenProxyHttpResponse.java | 154 +++++++++++++ .../customtokenproxy/ProxyConfig.java | 33 +++ .../implementation/util/IdentitySslUtil.java | 76 ++++++ 5 files changed, 600 insertions(+) create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java create mode 100644 sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java new file mode 100644 index 000000000000..b0849d29bf56 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -0,0 +1,119 @@ +package com.azure.identity.implementation.customtokenproxy; + +import com.azure.core.util.logging.ClientLogger; +import com.azure.identity.implementation.WorkloadIdentityTokenProxyPolicy; + +import java.net.URI; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.net.URISyntaxException; + +import com.azure.core.util.Configuration; +import com.azure.core.util.CoreUtils; + +public class CustomTokenProxyConfiguration { + + private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityTokenProxyPolicy.class); + + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; + public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; + + private CustomTokenProxyConfiguration() {} + + public static boolean isConfigured(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + return !CoreUtils.isNullOrEmpty(tokenProxyUrl); + } + + public static ProxyConfig parseAndValidate(Configuration configuration) { + String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); + String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE); + String caData = configuration.get(AZURE_KUBERNETES_CA_DATA); + String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); + + if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) + || !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); + } + throw LOGGER.logExceptionAsError(new IllegalStateException( + "AZURE_KUBERNETES_TOKEN_PROXY must be set to enable custom token proxy.")); + } + + if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { + throw LOGGER.logExceptionAsError(new IllegalStateException( + "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set.")); + } + + URL proxyUrl = validateProxyUrl(tokenProxyUrl); + + byte[] caCertBytes = null; + if(!CoreUtils.isNullOrEmpty(caData)) { + try { + caCertBytes = caData.getBytes(StandardCharsets.UTF_8); + } catch (Exception e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e)); + } + } + + return new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + } + + private static URL validateProxyUrl(String endpoint) { + if (CoreUtils.isNullOrEmpty(endpoint)) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty")); + } + + try { + URI tokenProxy = new URI(endpoint); + + if (!"https".equals(tokenProxy.getScheme())) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme())); + } + + if (tokenProxy.getRawUserInfo() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain user info: " + endpoint)); + } + + if (tokenProxy.getRawQuery() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain a query: " + endpoint)); + } + + if (tokenProxy.getRawFragment() != null) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException( + "Custom token endpoint URL must not contain a fragment: " + endpoint)); + } + + if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), + tokenProxy.getPort(), "/", null, null); + } + + return tokenProxy.toURL(); + + } catch (URISyntaxException | IllegalArgumentException e) { + throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e)); + } catch (Exception e) { + throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e); + } + } + + // public static String getTokenProxyUrl() { + // String tokenProxyUrl = System.getenv(AZURE_KUBERNETES_TOKEN_PROXY); + // if (tokenProxyUrl == null || tokenProxyUrl.isEmpty()) { + // throw LOGGER.logExceptionAsError(new IllegalStateException( + // String.format("Environment variable '%s' is not set or is empty. It must be set to the URL of the" + // + " token proxy.", AZURE_KUBERNETES_TOKEN_PROXY))); + // } + // return tokenProxyUrl; + // } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java new file mode 100644 index 000000000000..194d306eef5b --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -0,0 +1,218 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.MalformedParametersException; +import java.lang.reflect.Proxy; +import java.net.HttpURLConnection; +import java.net.URI; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Arrays; + +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; + +import com.azure.core.http.HttpClient; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; +import com.azure.core.util.Context; +import com.azure.core.util.CoreUtils; +import com.azure.identity.implementation.util.IdentitySslUtil; + +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpClient implements HttpClient { + + private final ProxyConfig proxyConfig; + private volatile SSLContext cachedSSLContext; + private volatile byte[] cachedFileContent; + + public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { + this.proxyConfig = proxyConfig; + } + + @Override + public Mono send(HttpRequest request) { + return Mono.fromCallable(() -> sendSync(request, Context.NONE)); + } + + @Override + public HttpResponse sendSync(HttpRequest request, Context context) { + try { + HttpURLConnection connection = createConnection(request); + connection.connect(); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw new RuntimeException("Failed to create connection to token proxy", e); + } + } + + // private HttpURLConnection createConnection(HttpRequest request) throws IOException { + // URL updateProxyRequest = rewriteTokenRequestForProxy(request.getUrl()); + // HttpsURLConnection connection = (HttpsURLConnection) updateProxyRequest.openConnection(); + // try { + // SSLContext sslContext = getSSLContext(); + // connection.setSSLSocketFactory(sslContext.getSocketFactory()); + + // if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + // SSLParameters sslParameters = connection.getSSLParameters(); + // } + // } + // return connection; + + // // connection.setRequestMethod(request.getMethod().toString()); + // // connection.setDoOutput(true); + // // request.getHeaders().forEach((key, values) -> { + // // values.forEach(value -> connection.addRequestProperty(key, value)); + // // }); + // // return connection; + // } + + private HttpURLConnection createConnection(HttpRequest request) throws IOException { + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + } + connection.setSSLSocketFactory(sslSocketFactory); + } catch (Exception e) { + throw new RuntimeException("Failed to set up SSL context for token proxy", e); + } + + connection.setRequestMethod(request.getHttpMethod().toString()); + // connection.setConnectTimeout(10_000); + // connection.setReadTimeout(20_000); + connection.setDoOutput(true); + + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); + + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + connection.getOutputStream().write(bytes); + } + } + + return connection; + } + + + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedParametersException{ + try { + String originalPath = originalUrl.getPath(); + String originalQuery = originalUrl.getQuery(); + + String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); + if(!tokenProxyBase.endsWith("/")) tokenProxyBase += "/"; + + URI combined = URI.create(tokenProxyBase).resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + + String combinedStr = combined.toString(); + if (originalQuery != null && !originalQuery.isEmpty()) { + combinedStr += "?" + originalQuery; + } + + return new URL(combinedStr); + + } catch (Exception e) { + throw new RuntimeException("Failed to rewrite token request for proxy", e); + } + } + + private SSLContext getSSLContext() { + try { + // If no CA override provide, use default + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) + && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = SSLContext.getDefault(); + } + } + return cachedSSLContext; + } + + // If CA data provided, use it + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())) { + synchronized (this) { + if (cachedSSLContext == null) { + cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); + } + } + return cachedSSLContext; + } + + // If CA file provided, read it (and re-read if it changes) + Path path = Paths.get(proxyConfig.getCaFile()); + if (!Files.exists(path)) { + throw new IOException("CA file not found: " + proxyConfig.getCaFile()); + } + + byte[] currentContent; + + synchronized (this) { + currentContent = Files.readAllBytes(path); + if (currentContent.length == 0) { + throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); + } + + if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) { + cachedSSLContext = createSslContextFromBytes(currentContent); + cachedFileContent = currentContent; + } + } + + return cachedSSLContext; + + } catch (Exception e) { + throw new RuntimeException("Failed to create default SSLContext", e); + } + } + + // Create SSLContext from byte array containing PEM certificate data + private SSLContext createSslContextFromBytes(byte[] certificateData) { + try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate caCert = (X509Certificate) cf.generateCertificate(inputStream); + return createSslContext(caCert); + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext from bytes", e); + } + } + + // Create SSLContext from a single X509Certificate + private SSLContext createSslContext(X509Certificate caCert) { + try { + KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); + keystore.load(null, null); + keystore.setCertificateEntry("ca-cert", caCert); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keystore); + + SSLContext context = SSLContext.getInstance("TLS"); + context.init(null, tmf.getTrustManagers(), null); + return context; + } catch (Exception e) { + throw new RuntimeException("Failed to create SSLContext", e); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java new file mode 100644 index 000000000000..33d148c86927 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -0,0 +1,154 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpRequest; +import com.azure.core.http.HttpResponse; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class CustomTokenProxyHttpResponse extends HttpResponse { + + // private final HttpRequest request; + private final int statusCode; + private final HttpHeaders headers; + private final HttpURLConnection connection; + private byte[] cachedRequestBodyBytes; + + public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { + super(request); + this.connection = connection; + this.statusCode = extractStatusCode(connection); + this.headers = extractHeaders(connection); + } + + private HttpHeaders extractHeaders(HttpURLConnection connection) { + HttpHeaders headers = new HttpHeaders(); + for (Map.Entry> entry : connection.getHeaderFields().entrySet()) { + String headerName = entry.getKey(); + if (headerName != null) { + for (String headerValue : entry.getValue()) { + headers.add(headerName, headerValue); + } + } + } + return headers; + } + + public int extractStatusCode(HttpURLConnection connection) { + try { + return connection.getResponseCode(); + } catch (IOException e) { + throw new RuntimeException("Failed to get status code from token proxy response", e); + } + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getHeaderValue(String name) { + return headers.getValue(HttpHeaderName.fromString(name)); + } + + @Override + public HttpHeaders getHeaders() { + return headers; + } + + // @Override + // public Mono getBodyAsByteArray() { + // return Mono.fromCallable(() -> { + // try (InputStream inputStream = connection.getInputStream()) { + // return inputStream.readAllBytes(); + // } catch (IOException e) { + // throw new RuntimeException("Failed to read body from token proxy response", e); + // } + // }); + // } + + @Override + public Mono getBodyAsByteArray() { + return Mono.fromCallable(() -> { + if (cachedRequestBodyBytes != null) { + return cachedRequestBodyBytes; + } + try (InputStream stream = getResponseStream()) { + if (stream == null) { + cachedRequestBodyBytes = new byte[0]; + return cachedRequestBodyBytes; + } + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] tmp = new byte[4096]; + int n; + while ((n = stream.read(tmp)) != -1) { + buffer.write(tmp, 0, n); + } + cachedRequestBodyBytes = buffer.toByteArray(); + return cachedRequestBodyBytes; + } + }); + } + + @Override + public Flux getBody() { + return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); + } + + @Override + public Mono getBodyAsString() { + return getBodyAsString(StandardCharsets.UTF_8); + } + + @Override + public Mono getBodyAsString(Charset charset) { + return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + } + + + @Override + public void close() { + connection.disconnect(); + } + + + + // @Override + // public Flux getBody() { + // // TODO Auto-generated method stub + // throw new UnsupportedOperationException("Unimplemented method 'getBody'"); + // } + + // @Override + // public Mono getBodyAsString() { + // return getBodyAsByteArray().map(bytes -> new String(bytes, StandardCharsets.UTF_8)); + // } + + // @Override + // public Mono getBodyAsString(Charset charset) { + // return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); + // } + + private InputStream getResponseStream() throws IOException { + try { + return connection.getInputStream(); + } catch (IOException e) { + // On non-2xx responses, getInputStream() throws, use error stream instead + return connection.getErrorStream(); + } + } + +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java new file mode 100644 index 000000000000..a104a8ab7053 --- /dev/null +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -0,0 +1,33 @@ +package com.azure.identity.implementation.customtokenproxy; + +import java.net.URL; + +public class ProxyConfig { + private final URL tokenProxyUrl; + private final String sniName; + private final String caFile; + private final byte[] caData; + + public ProxyConfig(URL tokenProxyUrl, String sniName, String caFile, byte[] caData) { + this.tokenProxyUrl = tokenProxyUrl; + this.sniName = sniName; + this.caFile = caFile; + this.caData = caData; + } + + public URL getTokenProxyUrl() { + return tokenProxyUrl; + } + + public String getSniName() { + return sniName; + } + + public String getCaFile() { + return caFile; + } + + public byte[] getCaData() { + return caData; + } +} diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index ba932d5e5a55..604dba0b59bf 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -3,15 +3,23 @@ package com.azure.identity.implementation.util; +import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; + +import java.io.IOException; +import java.net.Socket; import java.security.KeyManagementException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -19,6 +27,9 @@ import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; public final class IdentitySslUtil { public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER; @@ -125,4 +136,69 @@ private static String extractCertificateThumbprint(Certificate certificate, Clie throw logger.logExceptionAsError(new RuntimeException(e)); } } + + public static final class SniSslSocketFactory extends SSLSocketFactory { + private final SSLSocketFactory sslSocketFactory; + private final String sniName; + + public SniSslSocketFactory(SSLSocketFactory sslSocketFactory, String sniName) { + this.sslSocketFactory = sslSocketFactory; + this.sniName = sniName; + } + + @Override + public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException { + Socket sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, host, port, autoClose); + configureSni(sslSocket); + return sslSocket; + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress host, int port) throws IOException { + Socket socket = sslSocketFactory.createSocket(host, port); + configureSni(socket); + return socket; + } + + @Override + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); + configureSni(socket); + return socket; + } + + @Override + public String[] getDefaultCipherSuites() { + return sslSocketFactory.getDefaultCipherSuites(); + } + + @Override + public String[] getSupportedCipherSuites() { + return sslSocketFactory.getSupportedCipherSuites(); + } + + private void configureSni(Socket socket) { + if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { + SSLSocket sslSocket = (SSLSocket) socket; + SSLParameters sslParameters = sslSocket.getSSLParameters(); + sslParameters.setServerNames(Collections.singletonList(new SNIHostName(sniName))); + sslSocket.setSSLParameters(sslParameters); + } + } + + } } From b162676ee798f0302b67cf957326b0b425fab047 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sun, 19 Oct 2025 02:07:12 -0700 Subject: [PATCH 2/7] integration with wicbuilder and wic --- .../identity/WorkloadIdentityCredential.java | 11 +++ .../WorkloadIdentityCredentialBuilder.java | 8 +++ .../implementation/IdentityClientOptions.java | 15 +++- .../CustomTokenProxyConfiguration.java | 6 +- .../CustomTokenProxyHttpClient.java | 72 ++++++++++--------- .../CustomTokenProxyHttpResponse.java | 32 ++------- 6 files changed, 79 insertions(+), 65 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index a0c04ff2d952..3fedf5db30d9 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -10,6 +10,9 @@ import com.azure.core.util.CoreUtils; import com.azure.core.util.logging.ClientLogger; import com.azure.identity.implementation.IdentityClientOptions; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyConfiguration; +import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyHttpClient; +import com.azure.identity.implementation.customtokenproxy.ProxyConfig; import com.azure.identity.implementation.util.LoggingUtil; import com.azure.identity.implementation.util.ValidationUtil; import reactor.core.publisher.Mono; @@ -89,6 +92,14 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; + if(identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { + throw LOGGER.logExceptionAsError (new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + } + ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); + identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); + } + if (!(CoreUtils.isNullOrEmpty(tenantIdInput) || CoreUtils.isNullOrEmpty(federatedTokenFilePathInput) || CoreUtils.isNullOrEmpty(clientIdInput) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 95f16fd9628d..5520ad122bf4 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -47,6 +47,7 @@ public class WorkloadIdentityCredentialBuilder extends AadCredentialBuilderBase { private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityCredentialBuilder.class); private String tokenFilePath; + private boolean enableTokenProxy = false; /** * Creates an instance of a WorkloadIdentityCredentialBuilder. @@ -66,6 +67,11 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } + public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy(boolean enable) { + this.enableTokenProxy = enable; + return this; + } + /** * Creates new {@link WorkloadIdentityCredential} with the configured options set. * @@ -88,6 +94,8 @@ public WorkloadIdentityCredential build() { ValidationUtil.validate(this.getClass().getSimpleName(), LOGGER, "Client ID", clientIdInput, "Tenant ID", tenantIdInput, "Service Token File Path", federatedTokenFilePathInput); + identityClientOptions.setEnableKubernetesTokenProxy(this.enableTokenProxy); + return new WorkloadIdentityCredential(tenantIdInput, clientIdInput, federatedTokenFilePathInput, identityClientOptions.clone()); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index d934c23c1d67..4acab6df23ad 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -71,6 +71,7 @@ public final class IdentityClientOptions implements Cloneable { private List perRetryPolicies; private boolean instanceDiscovery; private String dacEnvConfiguredCredential; + private boolean enableKubernetesTokenProxy; private Duration credentialProcessTimeout = Duration.ofSeconds(10); @@ -833,6 +834,15 @@ public String getDACEnvConfiguredCredential() { return dacEnvConfiguredCredential; } + public boolean isKubernetesTokenProxyEnabled() { + return enableKubernetesTokenProxy; + } + + public IdentityClientOptions setEnableKubernetesTokenProxy(boolean enableTokenProxy) { + this.enableKubernetesTokenProxy = enableTokenProxy; + return this; + } + public IdentityClientOptions clone() { IdentityClientOptions clone = new IdentityClientOptions().setAdditionallyAllowedTenants(this.additionallyAllowedTenants) @@ -863,8 +873,9 @@ public IdentityClientOptions clone() { .setPerRetryPolicies(this.perRetryPolicies) .setBrowserCustomizationOptions(this.browserCustomizationOptions) .setChained(this.isChained) - .subscription(this.subscription); - + .subscription(this.subscription) + .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); + if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); clone.setEnableLegacyMsaPassthrough(this.enableMsaPassthrough); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index b0849d29bf56..0b25b64935b3 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -1,7 +1,6 @@ package com.azure.identity.implementation.customtokenproxy; import com.azure.core.util.logging.ClientLogger; -import com.azure.identity.implementation.WorkloadIdentityTokenProxyPolicy; import java.net.URI; import java.net.URL; @@ -13,7 +12,7 @@ public class CustomTokenProxyConfiguration { - private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityTokenProxyPolicy.class); + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; @@ -40,8 +39,7 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { throw LOGGER.logExceptionAsError(new IllegalStateException( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); } - throw LOGGER.logExceptionAsError(new IllegalStateException( - "AZURE_KUBERNETES_TOKEN_PROXY must be set to enable custom token proxy.")); + return null; } if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index 194d306eef5b..e37e166a02fe 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -6,15 +6,19 @@ import java.lang.reflect.MalformedParametersException; import java.lang.reflect.Proxy; import java.net.HttpURLConnection; +import java.net.MalformedURLException; import java.net.URI; import java.net.URL; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.security.KeyStore; +import java.security.cert.CertificateException; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; @@ -50,34 +54,12 @@ public Mono send(HttpRequest request) { public HttpResponse sendSync(HttpRequest request, Context context) { try { HttpURLConnection connection = createConnection(request); - connection.connect(); return new CustomTokenProxyHttpResponse(request, connection); } catch (IOException e) { throw new RuntimeException("Failed to create connection to token proxy", e); } } - // private HttpURLConnection createConnection(HttpRequest request) throws IOException { - // URL updateProxyRequest = rewriteTokenRequestForProxy(request.getUrl()); - // HttpsURLConnection connection = (HttpsURLConnection) updateProxyRequest.openConnection(); - // try { - // SSLContext sslContext = getSSLContext(); - // connection.setSSLSocketFactory(sslContext.getSocketFactory()); - - // if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { - // SSLParameters sslParameters = connection.getSSLParameters(); - // } - // } - // return connection; - - // // connection.setRequestMethod(request.getMethod().toString()); - // // connection.setDoOutput(true); - // // request.getHeaders().forEach((key, values) -> { - // // values.forEach(value -> connection.addRequestProperty(key, value)); - // // }); - // // return connection; - // } - private HttpURLConnection createConnection(HttpRequest request) throws IOException { URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); @@ -95,8 +77,9 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti } connection.setRequestMethod(request.getHttpMethod().toString()); - // connection.setConnectTimeout(10_000); - // connection.setReadTimeout(20_000); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); connection.setDoOutput(true); request.getHeaders().forEach(header -> { @@ -114,7 +97,7 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti } - private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedParametersException{ + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException{ try { String originalPath = originalUrl.getPath(); String originalQuery = originalUrl.getQuery(); @@ -165,10 +148,9 @@ private SSLContext getSSLContext() { throw new IOException("CA file not found: " + proxyConfig.getCaFile()); } - byte[] currentContent; - + byte[] currentContent = Files.readAllBytes(path); + synchronized (this) { - currentContent = Files.readAllBytes(path); if (currentContent.length == 0) { throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); } @@ -182,7 +164,7 @@ private SSLContext getSSLContext() { return cachedSSLContext; } catch (Exception e) { - throw new RuntimeException("Failed to create default SSLContext", e); + throw new RuntimeException("Failed to initialize SSLContext for proxy", e); } } @@ -190,20 +172,42 @@ private SSLContext getSSLContext() { private SSLContext createSslContextFromBytes(byte[] certificateData) { try (InputStream inputStream = new ByteArrayInputStream(certificateData)) { CertificateFactory cf = CertificateFactory.getInstance("X.509"); - X509Certificate caCert = (X509Certificate) cf.generateCertificate(inputStream); - return createSslContext(caCert); + + List certificates = new ArrayList<>(); + // while(inputStream.available() > 0) { + // X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + // certificates.add(cert); + // } + while (true) { + try { + X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); + certificates.add(cert); + } catch (CertificateException e) { + break; // end of stream + } + } + + if (certificates.isEmpty()) { + throw new RuntimeException("No valid certificates found"); + } + + // X509Certificate caCert = certificates.get(0); + return createSslContext(certificates); } catch (Exception e) { throw new RuntimeException("Failed to create SSLContext from bytes", e); } } // Create SSLContext from a single X509Certificate - private SSLContext createSslContext(X509Certificate caCert) { + private SSLContext createSslContext(List certificates) { try { KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType()); keystore.load(null, null); - keystore.setCertificateEntry("ca-cert", caCert); - + int index = 1; + for (X509Certificate caCert : certificates) { + keystore.setCertificateEntry("ca-cert-" + index, caCert); + index++; + } TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); tmf.init(keystore); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 33d148c86927..88f5d38eb061 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -24,7 +24,7 @@ public class CustomTokenProxyHttpResponse extends HttpResponse { private final int statusCode; private final HttpHeaders headers; private final HttpURLConnection connection; - private byte[] cachedRequestBodyBytes; + private byte[] cachedResponseBodyBytes; public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) { super(request); @@ -83,13 +83,13 @@ public HttpHeaders getHeaders() { @Override public Mono getBodyAsByteArray() { return Mono.fromCallable(() -> { - if (cachedRequestBodyBytes != null) { - return cachedRequestBodyBytes; + if (cachedResponseBodyBytes != null) { + return cachedResponseBodyBytes; } try (InputStream stream = getResponseStream()) { if (stream == null) { - cachedRequestBodyBytes = new byte[0]; - return cachedRequestBodyBytes; + cachedResponseBodyBytes = new byte[0]; + return cachedResponseBodyBytes; } ByteArrayOutputStream buffer = new ByteArrayOutputStream(); byte[] tmp = new byte[4096]; @@ -97,8 +97,8 @@ public Mono getBodyAsByteArray() { while ((n = stream.read(tmp)) != -1) { buffer.write(tmp, 0, n); } - cachedRequestBodyBytes = buffer.toByteArray(); - return cachedRequestBodyBytes; + cachedResponseBodyBytes = buffer.toByteArray(); + return cachedResponseBodyBytes; } }); } @@ -124,24 +124,6 @@ public void close() { connection.disconnect(); } - - - // @Override - // public Flux getBody() { - // // TODO Auto-generated method stub - // throw new UnsupportedOperationException("Unimplemented method 'getBody'"); - // } - - // @Override - // public Mono getBodyAsString() { - // return getBodyAsByteArray().map(bytes -> new String(bytes, StandardCharsets.UTF_8)); - // } - - // @Override - // public Mono getBodyAsString(Charset charset) { - // return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); - // } - private InputStream getResponseStream() throws IOException { try { return connection.getInputStream(); From 85022597fbd1f34a860bbf93662a3b61c0dd01d3 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sun, 19 Oct 2025 02:54:48 -0700 Subject: [PATCH 3/7] review comments --- .../identity/WorkloadIdentityCredential.java | 5 +- .../implementation/IdentityClientOptions.java | 2 +- .../CustomTokenProxyConfiguration.java | 27 +++--- .../CustomTokenProxyHttpClient.java | 94 +++++++++---------- .../CustomTokenProxyHttpResponse.java | 5 +- .../implementation/util/IdentitySslUtil.java | 9 +- 6 files changed, 71 insertions(+), 71 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index 3fedf5db30d9..f0653bf91522 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -92,9 +92,10 @@ public class WorkloadIdentityCredential implements TokenCredential { ClientAssertionCredential tempClientAssertionCredential = null; String tempClientId = null; - if(identityClientOptions.isKubernetesTokenProxyEnabled()) { + if (identityClientOptions.isKubernetesTokenProxyEnabled()) { if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { - throw LOGGER.logExceptionAsError (new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); } ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java index 4acab6df23ad..3c2d21d08c56 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java @@ -875,7 +875,7 @@ public IdentityClientOptions clone() { .setChained(this.isChained) .subscription(this.subscription) .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy); - + if (isBrokerEnabled()) { clone.setBrokerWindowHandle(this.brokerWindowHandle); clone.setEnableLegacyMsaPassthrough(this.enableMsaPassthrough); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index 0b25b64935b3..4670457b423b 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -13,13 +13,14 @@ public class CustomTokenProxyConfiguration { private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); - + public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY"; public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; - private CustomTokenProxyConfiguration() {} + private CustomTokenProxyConfiguration() { + } public static boolean isConfigured(Configuration configuration) { String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY); @@ -33,8 +34,8 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME); if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) { - if (!CoreUtils.isNullOrEmpty(sniName) - || !CoreUtils.isNullOrEmpty(caFile) + if (!CoreUtils.isNullOrEmpty(sniName) + || !CoreUtils.isNullOrEmpty(caFile) || !CoreUtils.isNullOrEmpty(caData)) { throw LOGGER.logExceptionAsError(new IllegalStateException( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); @@ -50,7 +51,7 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { URL proxyUrl = validateProxyUrl(tokenProxyUrl); byte[] caCertBytes = null; - if(!CoreUtils.isNullOrEmpty(caData)) { + if (!CoreUtils.isNullOrEmpty(caData)) { try { caCertBytes = caData.getBytes(StandardCharsets.UTF_8); } catch (Exception e) { @@ -76,23 +77,23 @@ private static URL validateProxyUrl(String endpoint) { } if (tokenProxy.getRawUserInfo() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain user info: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain user info: " + endpoint)); } if (tokenProxy.getRawQuery() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain a query: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a query: " + endpoint)); } if (tokenProxy.getRawFragment() != null) { - throw LOGGER.logExceptionAsError(new IllegalArgumentException( - "Custom token endpoint URL must not contain a fragment: " + endpoint)); + throw LOGGER.logExceptionAsError( + new IllegalArgumentException("Custom token endpoint URL must not contain a fragment: " + endpoint)); } if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) { - tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), - tokenProxy.getPort(), "/", null, null); + tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), tokenProxy.getPort(), "/", + null, null); } return tokenProxy.toURL(); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index e37e166a02fe..9794dc4a5609 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -3,8 +3,6 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.MalformedParametersException; -import java.lang.reflect.Proxy; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URI; @@ -22,7 +20,6 @@ import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManagerFactory; @@ -49,70 +46,71 @@ public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { public Mono send(HttpRequest request) { return Mono.fromCallable(() -> sendSync(request, Context.NONE)); } - + @Override public HttpResponse sendSync(HttpRequest request, Context context) { - try { - HttpURLConnection connection = createConnection(request); - return new CustomTokenProxyHttpResponse(request, connection); - } catch (IOException e) { - throw new RuntimeException("Failed to create connection to token proxy", e); - } + try { + HttpURLConnection connection = createConnection(request); + return new CustomTokenProxyHttpResponse(request, connection); + } catch (IOException e) { + throw new RuntimeException("Failed to create connection to token proxy", e); + } } private HttpURLConnection createConnection(HttpRequest request) throws IOException { - URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); - HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); - - // If SNI explicitly provided - try { - SSLContext sslContext = getSSLContext(); - SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); - if(!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { - sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); + HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); + + // If SNI explicitly provided + try { + SSLContext sslContext = getSSLContext(); + SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); + if (!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + } + connection.setSSLSocketFactory(sslSocketFactory); + } catch (Exception e) { + throw new RuntimeException("Failed to set up SSL context for token proxy", e); } - connection.setSSLSocketFactory(sslSocketFactory); - } catch (Exception e) { - throw new RuntimeException("Failed to set up SSL context for token proxy", e); - } - connection.setRequestMethod(request.getHttpMethod().toString()); - connection.setInstanceFollowRedirects(false); - connection.setConnectTimeout(10_000); - connection.setReadTimeout(20_000); - connection.setDoOutput(true); + connection.setRequestMethod(request.getHttpMethod().toString()); + connection.setInstanceFollowRedirects(false); + connection.setConnectTimeout(10_000); + connection.setReadTimeout(20_000); + connection.setDoOutput(true); - request.getHeaders().forEach(header -> { - connection.addRequestProperty(header.getName(), header.getValue()); - }); + request.getHeaders().forEach(header -> { + connection.addRequestProperty(header.getName(), header.getValue()); + }); - if (request.getBodyAsBinaryData() != null) { - byte[] bytes = request.getBodyAsBinaryData().toBytes(); - if (bytes != null && bytes.length > 0) { - connection.getOutputStream().write(bytes); + if (request.getBodyAsBinaryData() != null) { + byte[] bytes = request.getBodyAsBinaryData().toBytes(); + if (bytes != null && bytes.length > 0) { + connection.getOutputStream().write(bytes); + } } - } - return connection; + return connection; } - - private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException{ + private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException { try { String originalPath = originalUrl.getPath(); String originalQuery = originalUrl.getQuery(); String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); - if(!tokenProxyBase.endsWith("/")) tokenProxyBase += "/"; + if (!tokenProxyBase.endsWith("/")) + tokenProxyBase += "/"; - URI combined = URI.create(tokenProxyBase).resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); + URI combined = URI.create(tokenProxyBase) + .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); String combinedStr = combined.toString(); if (originalQuery != null && !originalQuery.isEmpty()) { combinedStr += "?" + originalQuery; } - return new URL(combinedStr); + return new URL(combinedStr); } catch (Exception e) { throw new RuntimeException("Failed to rewrite token request for proxy", e); @@ -121,15 +119,15 @@ private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLExce private SSLContext getSSLContext() { try { - // If no CA override provide, use default - if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) - && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + // If no CA override provided, use default + if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) + && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { synchronized (this) { if (cachedSSLContext == null) { cachedSSLContext = SSLContext.getDefault(); } } - return cachedSSLContext; + return cachedSSLContext; } // If CA data provided, use it @@ -139,7 +137,7 @@ private SSLContext getSSLContext() { cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); } } - return cachedSSLContext; + return cachedSSLContext; } // If CA file provided, read it (and re-read if it changes) @@ -164,7 +162,7 @@ private SSLContext getSSLContext() { return cachedSSLContext; } catch (Exception e) { - throw new RuntimeException("Failed to initialize SSLContext for proxy", e); + throw new RuntimeException("Failed to initialize SSLContext for proxy", e); } } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 88f5d38eb061..f9fe17265891 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -46,7 +46,7 @@ private HttpHeaders extractHeaders(HttpURLConnection connection) { return headers; } - public int extractStatusCode(HttpURLConnection connection) { + private int extractStatusCode(HttpURLConnection connection) { try { return connection.getResponseCode(); } catch (IOException e) { @@ -106,7 +106,7 @@ public Mono getBodyAsByteArray() { @Override public Flux getBody() { return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes))); - } + } @Override public Mono getBodyAsString() { @@ -117,7 +117,6 @@ public Mono getBodyAsString() { public Mono getBodyAsString(Charset charset) { return getBodyAsByteArray().map(bytes -> new String(bytes, charset)); } - @Override public void close() { diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index 604dba0b59bf..01d0e7edf894 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -9,7 +9,6 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SNIHostName; -import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; @@ -161,7 +160,8 @@ public Socket createSocket(String host, int port) throws IOException { } @Override - public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort) + throws IOException { Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort); configureSni(socket); return socket; @@ -175,7 +175,8 @@ public Socket createSocket(java.net.InetAddress host, int port) throws IOExcepti } @Override - public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException { + public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, + int localPort) throws IOException { Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort); configureSni(socket); return socket; @@ -191,7 +192,7 @@ public String[] getSupportedCipherSuites() { return sslSocketFactory.getSupportedCipherSuites(); } - private void configureSni(Socket socket) { + private void configureSni(Socket socket) { if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { SSLSocket sslSocket = (SSLSocket) socket; SSLParameters sslParameters = sslSocket.getSSLParameters(); From 7a521b50c565a273e79de632e1b242644cf2abd9 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Thu, 23 Oct 2025 12:12:27 -0700 Subject: [PATCH 4/7] added wic testcases and some review comments --- .../identity/WorkloadIdentityCredential.java | 8 +- .../WorkloadIdentityCredentialBuilder.java | 4 +- .../CustomTokenProxyConfiguration.java | 4 +- .../CustomTokenProxyHttpClient.java | 6 +- .../CustomTokenProxyHttpResponse.java | 1 - .../WorkloadIdentityCredentialTest.java | 302 ++++++++++++++++++ 6 files changed, 314 insertions(+), 11 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java index f0653bf91522..40c3ac8ce4b4 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java @@ -93,12 +93,10 @@ public class WorkloadIdentityCredential implements TokenCredential { String tempClientId = null; if (identityClientOptions.isKubernetesTokenProxyEnabled()) { - if (!CustomTokenProxyConfiguration.isConfigured(configuration)) { - throw LOGGER.logExceptionAsError( - new IllegalArgumentException("Kubernetes token proxy is enabled but not configured.")); + if (CustomTokenProxyConfiguration.isConfigured(configuration)) { + ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); + identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); } - ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration); - identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig)); } if (!(CoreUtils.isNullOrEmpty(tenantIdInput) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 5520ad122bf4..659ebe920c5d 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -67,8 +67,8 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } - public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy(boolean enable) { - this.enableTokenProxy = enable; + public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy() { + this.enableTokenProxy = true; return this; } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index 4670457b423b..9e079bede725 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -37,14 +37,14 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { if (!CoreUtils.isNullOrEmpty(sniName) || !CoreUtils.isNullOrEmpty(caFile) || !CoreUtils.isNullOrEmpty(caData)) { - throw LOGGER.logExceptionAsError(new IllegalStateException( + throw LOGGER.logExceptionAsError(new IllegalArgumentException( "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present")); } return null; } if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) { - throw LOGGER.logExceptionAsError(new IllegalStateException( + throw LOGGER.logExceptionAsError(new IllegalArgumentException( "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set.")); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index 9794dc4a5609..fb2adc614f45 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -3,6 +3,7 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URI; @@ -86,7 +87,10 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti if (request.getBodyAsBinaryData() != null) { byte[] bytes = request.getBodyAsBinaryData().toBytes(); if (bytes != null && bytes.length > 0) { - connection.getOutputStream().write(bytes); + try (OutputStream os = connection.getOutputStream()) { + os.write(bytes); + os.flush(); + } } } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index f9fe17265891..0c4d597231a2 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -14,7 +14,6 @@ import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpRequest; import com.azure.core.http.HttpResponse; - import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java index d2404d9b302b..0f11e051337c 100644 --- a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java +++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java @@ -35,6 +35,10 @@ public class WorkloadIdentityCredentialTest { private static final String CLIENT_ID = UUID.randomUUID().toString(); + private static final String ENV_PROXY_URL = "AZURE_KUBERNETES_TOKEN_PROXY"; + private static final String ENV_CA_FILE = "AZURE_KUBERNETES_CA_FILE"; + private static final String ENV_CA_DATA = "AZURE_KUBERNETES_CA_DATA"; + private static final String ENV_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME"; @Test public void testWorkloadIdentityFlow(@TempDir Path tempDir) throws IOException { @@ -192,4 +196,302 @@ public void testFileReadingError(@TempDir Path tempDir) { assertTrue(error.getCause() instanceof IOException); // Original IOException from Files.readAllBytes }).verify(); } + + @Test + public void testProxyEnabledWithProxyUrlGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + + @Test + public void testProxyEnabledWithoutProxyUrlGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)); + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + + @Test + public void testProxyEnabledInvalidProxyUrlSchemeFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://localhost"; + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "http://not-https.example.com")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithQueryFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com?x=y")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithFragmentFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com#frag")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyUrlWithUserInfoFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://user:pass@proxy.example.com")); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testCaFileAndCaDataPresentFailure(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Path caFile = tempDir.resolve("ca.crt"); + Files.write(caFile, + "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8)); + + String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----"; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://proxy.example.com") + .put(ENV_CA_FILE, caFile.toString()) + .put(ENV_CA_DATA, caData)); + + Assertions.assertThrows(IllegalArgumentException.class, () -> { + new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + }); + } + + @Test + public void testProxyEnabledWithProxyUrlGetsTokenSync(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identitySyncClientMock + = mockConstruction(IdentitySyncClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())) + .thenThrow(new IllegalStateException("Test")); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessTokenSync(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + AccessToken token = credential.getTokenSync(request1); + + assertTrue(token1.equals(token.getToken())); + assertTrue(expiresAt.getSecond() == token.getExpiresAt().getSecond()); + assertNotNull(identitySyncClientMock); + } + } + + @Test + public void testProxyUrlWithCaDataAcquiresToken(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + String token1 = "token-ca-data"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + + String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----"; + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://token-proxy.example.com") + .put(ENV_CA_DATA, caData)); + + try (MockedConstruction mocked + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(cred.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && token.getExpiresAt().getSecond() == expiresAt.getSecond()) + .verifyComplete(); + + assertNotNull(mocked); + } + } + + @Test + public void testProxyUrlWithCaFileGetsToken(@TempDir Path tempDir) throws IOException { + String endpoint = "https://login.microsoftonline.com"; + String token1 = "tok-ca-file"; + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + Path caFile = tempDir.resolve("ca.crt"); + Files.write(caFile, + "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, "https://token-proxy.example.com") + .put(ENV_CA_FILE, caFile.toString())); + + try (MockedConstruction mocked + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(cred.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && token.getExpiresAt().getSecond() == expiresAt.getSecond()) + .verifyComplete(); + + assertNotNull(mocked); + } + } + } From 0c69110a83052c322b4003801ae1605d135de22b Mon Sep 17 00:00:00 2001 From: anannya03 Date: Sun, 26 Oct 2025 12:27:33 -0700 Subject: [PATCH 5/7] unit test --- .../customtokenproxy/ProxyConfig.java | 3 ++ .../WorkloadIdentityCredentialTest.java | 40 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java index a104a8ab7053..098b25198fd3 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + package com.azure.identity.implementation.customtokenproxy; import java.net.URL; diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java index 0f11e051337c..e52c3ef8735f 100644 --- a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java +++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java @@ -494,4 +494,44 @@ public void testProxyUrlWithCaFileGetsToken(@TempDir Path tempDir) throws IOExce } } + @Test + public void testProxyEnabledWithSniNameGetsToken(@TempDir Path tempDir) throws IOException { + // setup + String endpoint = "https://localhost"; + String token1 = "token1"; + String proxyUrl = "https://token-proxy.example.com"; + String sniName = "615f3b8ad7eb011a09ed3b762e404de43ebc7ade0802a34c9fd322b688c3a655.ests.aks"; + + TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1); + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint) + .put(ENV_PROXY_URL, proxyUrl) + .put(ENV_SNI_NAME, sniName)); + + Path tokenFile = tempDir.resolve("token.txt"); + Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8)); + + try (MockedConstruction identityClientMock + = mockConstruction(IdentityClient.class, (identityClient, context) -> { + when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty()); + when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class))) + .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt)); + })) { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid") + .clientId(CLIENT_ID) + .tokenFilePath(tokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .build(); + + StepVerifier.create(credential.getToken(request1)) + .expectNextMatches(token -> token1.equals(token.getToken()) + && expiresAt.getSecond() == token.getExpiresAt().getSecond()) + .verifyComplete(); + + assertNotNull(identityClientMock); + } + } + } From 89396ed596e61caa7a19ca72522544041e01344e Mon Sep 17 00:00:00 2001 From: anannya03 Date: Tue, 28 Oct 2025 20:21:21 -0700 Subject: [PATCH 6/7] live testing based updates --- .../CustomTokenProxyConfiguration.java | 16 ++++-------- .../CustomTokenProxyHttpClient.java | 25 +++++++++++-------- .../CustomTokenProxyHttpResponse.java | 15 +++-------- .../implementation/util/IdentitySslUtil.java | 12 ++++++--- 4 files changed, 31 insertions(+), 37 deletions(-) diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index 9e079bede725..312450e088b0 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + package com.azure.identity.implementation.customtokenproxy; import com.azure.core.util.logging.ClientLogger; @@ -60,7 +63,8 @@ public static ProxyConfig parseAndValidate(Configuration configuration) { } } - return new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + ProxyConfig config = new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes); + return config; } private static URL validateProxyUrl(String endpoint) { @@ -105,14 +109,4 @@ private static URL validateProxyUrl(String endpoint) { } } - // public static String getTokenProxyUrl() { - // String tokenProxyUrl = System.getenv(AZURE_KUBERNETES_TOKEN_PROXY); - // if (tokenProxyUrl == null || tokenProxyUrl.isEmpty()) { - // throw LOGGER.logExceptionAsError(new IllegalStateException( - // String.format("Environment variable '%s' is not set or is empty. It must be set to the URL of the" - // + " token proxy.", AZURE_KUBERNETES_TOKEN_PROXY))); - // } - // return tokenProxyUrl; - // } - } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index fb2adc614f45..88997555573c 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + package com.azure.identity.implementation.customtokenproxy; import java.io.ByteArrayInputStream; @@ -19,6 +22,7 @@ import java.util.Arrays; import java.util.List; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -61,7 +65,6 @@ public HttpResponse sendSync(HttpRequest request, Context context) { private HttpURLConnection createConnection(HttpRequest request) throws IOException { URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl()); HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection(); - // If SNI explicitly provided try { SSLContext sslContext = getSSLContext(); @@ -70,6 +73,7 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); } connection.setSSLSocketFactory(sslSocketFactory); + connection.setHostnameVerifier(sniAwareVerifier(proxyConfig.getSniName(), proxyConfig.getTokenProxyUrl())); } catch (Exception e) { throw new RuntimeException("Failed to set up SSL context for token proxy", e); } @@ -93,7 +97,6 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti } } } - return connection; } @@ -162,7 +165,6 @@ private SSLContext getSSLContext() { cachedFileContent = currentContent; } } - return cachedSSLContext; } catch (Exception e) { @@ -176,24 +178,18 @@ private SSLContext createSslContextFromBytes(byte[] certificateData) { CertificateFactory cf = CertificateFactory.getInstance("X.509"); List certificates = new ArrayList<>(); - // while(inputStream.available() > 0) { - // X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); - // certificates.add(cert); - // } while (true) { try { X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream); certificates.add(cert); } catch (CertificateException e) { - break; // end of stream + break; } } if (certificates.isEmpty()) { throw new RuntimeException("No valid certificates found"); } - - // X509Certificate caCert = certificates.get(0); return createSslContext(certificates); } catch (Exception e) { throw new RuntimeException("Failed to create SSLContext from bytes", e); @@ -221,4 +217,13 @@ private SSLContext createSslContext(List certificates) { } } + private static HostnameVerifier sniAwareVerifier(String sniName, URL proxyUrl) { + return (urlHost, session) -> { + String peerHost = session.getPeerHost(); + String expectedProxyHost = proxyUrl.getHost(); + return peerHost.equalsIgnoreCase(expectedProxyHost) + || (!CoreUtils.isNullOrEmpty(sniName) && peerHost.equalsIgnoreCase(sniName)); + }; + } + } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 0c4d597231a2..973c04c48891 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + package com.azure.identity.implementation.customtokenproxy; import java.io.ByteArrayOutputStream; @@ -68,17 +71,6 @@ public HttpHeaders getHeaders() { return headers; } - // @Override - // public Mono getBodyAsByteArray() { - // return Mono.fromCallable(() -> { - // try (InputStream inputStream = connection.getInputStream()) { - // return inputStream.readAllBytes(); - // } catch (IOException e) { - // throw new RuntimeException("Failed to read body from token proxy response", e); - // } - // }); - // } - @Override public Mono getBodyAsByteArray() { return Mono.fromCallable(() -> { @@ -126,7 +118,6 @@ private InputStream getResponseStream() throws IOException { try { return connection.getInputStream(); } catch (IOException e) { - // On non-2xx responses, getInputStream() throws, use error stream instead return connection.getErrorStream(); } } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java index 01d0e7edf894..5f6b7161cd69 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java @@ -8,7 +8,7 @@ import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; -import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; @@ -19,6 +19,7 @@ import java.io.IOException; import java.net.Socket; +import java.nio.charset.StandardCharsets; import java.security.KeyManagementException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -26,9 +27,7 @@ import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.util.ArrayList; import java.util.Collections; -import java.util.List; public final class IdentitySslUtil { public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER; @@ -196,10 +195,15 @@ private void configureSni(Socket socket) { if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) { SSLSocket sslSocket = (SSLSocket) socket; SSLParameters sslParameters = sslSocket.getSSLParameters(); - sslParameters.setServerNames(Collections.singletonList(new SNIHostName(sniName))); + sslParameters.setServerNames(Collections.singletonList(new RawSniServerName(sniName))); sslSocket.setSSLParameters(sslParameters); } } + } + public static final class RawSniServerName extends SNIServerName { + public RawSniServerName(String sniHost) { + super(0, sniHost.getBytes(StandardCharsets.UTF_8)); + } } } From 702cc3f5e33a4dc52bc02f06e2454501f0caa539 Mon Sep 17 00:00:00 2001 From: anannya03 Date: Wed, 29 Oct 2025 02:17:05 -0700 Subject: [PATCH 7/7] spotless, spotbugs, checkstyle and local server test --- sdk/identity/azure-identity/pom.xml | 7 + .../WorkloadIdentityCredentialBuilder.java | 8 + .../CustomTokenProxyConfiguration.java | 4 +- .../CustomTokenProxyHttpClient.java | 69 ++- .../CustomTokenProxyHttpResponse.java | 27 +- .../customtokenproxy/ProxyConfig.java | 4 +- ...IdentityCredentialIdentityBindingTest.java | 555 ++++++++++++++++++ 7 files changed, 637 insertions(+), 37 deletions(-) create mode 100644 sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java diff --git a/sdk/identity/azure-identity/pom.xml b/sdk/identity/azure-identity/pom.xml index 82dd64b7a0f5..a3eb2cc64d4a 100644 --- a/sdk/identity/azure-identity/pom.xml +++ b/sdk/identity/azure-identity/pom.xml @@ -115,6 +115,13 @@ 1.17.7 test + + + org.bouncycastle + bcpkix-lts8on + 2.73.8 + test + diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java index 659ebe920c5d..31e62c6f4e14 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java @@ -67,6 +67,14 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) { return this; } + /** + * Enables the Kubernetes token proxy feature for AKS workload identity scenarios. + * When enabled, the credential will attempt to use a custom token proxy configured through + * environment variables (AZURE_KUBERNETES_TOKEN_PROXY, AZURE_KUBERNETES_CA_FILE, + * AZURE_KUBERNETES_CA_DATA, AZURE_KUBERNETES_SNI_NAME). + * + * @return An updated instance of this builder with Kubernetes token proxy enabled. + */ public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy() { this.enableTokenProxy = true; return this; diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java index 312450e088b0..da7eddc1e38b 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package com.azure.identity.implementation.customtokenproxy; @@ -13,7 +13,7 @@ import com.azure.core.util.Configuration; import com.azure.core.util.CoreUtils; -public class CustomTokenProxyConfiguration { +public final class CustomTokenProxyConfiguration { private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class); diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java index 88997555573c..4cad1ba381be 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. package com.azure.identity.implementation.customtokenproxy; @@ -33,18 +33,29 @@ import com.azure.core.http.HttpResponse; import com.azure.core.util.Context; import com.azure.core.util.CoreUtils; +import com.azure.core.util.logging.ClientLogger; import com.azure.identity.implementation.util.IdentitySslUtil; import reactor.core.publisher.Mono; public class CustomTokenProxyHttpClient implements HttpClient { + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpClient.class); + private final ProxyConfig proxyConfig; private volatile SSLContext cachedSSLContext; private volatile byte[] cachedFileContent; + private final URL proxyUrl; + private final String sniName; + private final byte[] caData; + private final String caFile; public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) { this.proxyConfig = proxyConfig; + this.proxyUrl = proxyConfig.getTokenProxyUrl(); + this.sniName = proxyConfig.getSniName(); + this.caData = proxyConfig.getCaData(); + this.caFile = proxyConfig.getCaFile(); } @Override @@ -58,7 +69,7 @@ public HttpResponse sendSync(HttpRequest request, Context context) { HttpURLConnection connection = createConnection(request); return new CustomTokenProxyHttpResponse(request, connection); } catch (IOException e) { - throw new RuntimeException("Failed to create connection to token proxy", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create connection to token proxy", e)); } } @@ -69,20 +80,28 @@ private HttpURLConnection createConnection(HttpRequest request) throws IOExcepti try { SSLContext sslContext = getSSLContext(); SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); - if (!CoreUtils.isNullOrEmpty(proxyConfig.getSniName())) { - sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, proxyConfig.getSniName()); + if (!CoreUtils.isNullOrEmpty(sniName)) { + sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, sniName); } connection.setSSLSocketFactory(sslSocketFactory); - connection.setHostnameVerifier(sniAwareVerifier(proxyConfig.getSniName(), proxyConfig.getTokenProxyUrl())); + connection.setHostnameVerifier(sniAwareVerifier(sniName, proxyUrl)); } catch (Exception e) { - throw new RuntimeException("Failed to set up SSL context for token proxy", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to set up SSL context for token proxy", e)); } - connection.setRequestMethod(request.getHttpMethod().toString()); + String method = request.getHttpMethod().toString(); + connection.setRequestMethod(method); connection.setInstanceFollowRedirects(false); connection.setConnectTimeout(10_000); connection.setReadTimeout(20_000); - connection.setDoOutput(true); + + boolean hasBody = request.getBodyAsBinaryData() != null && request.getBodyAsBinaryData().toBytes() != null; + + if (hasBody && ("POST".equals(method) || "PUT".equals(method) || "PATCH".equals(method))) { + connection.setDoOutput(true); + } else { + connection.setDoOutput(false); + } request.getHeaders().forEach(header -> { connection.addRequestProperty(header.getName(), header.getValue()); @@ -105,9 +124,10 @@ private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLExce String originalPath = originalUrl.getPath(); String originalQuery = originalUrl.getQuery(); - String tokenProxyBase = proxyConfig.getTokenProxyUrl().toString(); - if (!tokenProxyBase.endsWith("/")) + String tokenProxyBase = proxyUrl.toString(); + if (!tokenProxyBase.endsWith("/")) { tokenProxyBase += "/"; + } URI combined = URI.create(tokenProxyBase) .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath); @@ -120,15 +140,14 @@ private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLExce return new URL(combinedStr); } catch (Exception e) { - throw new RuntimeException("Failed to rewrite token request for proxy", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to rewrite token request for proxy", e)); } } private SSLContext getSSLContext() { try { // If no CA override provided, use default - if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile()) - && (proxyConfig.getCaData() == null || proxyConfig.getCaData().length == 0)) { + if (CoreUtils.isNullOrEmpty(caFile) && (caData == null || caData.length == 0)) { synchronized (this) { if (cachedSSLContext == null) { cachedSSLContext = SSLContext.getDefault(); @@ -138,26 +157,26 @@ private SSLContext getSSLContext() { } // If CA data provided, use it - if (CoreUtils.isNullOrEmpty(proxyConfig.getCaFile())) { + if (CoreUtils.isNullOrEmpty(caFile)) { synchronized (this) { if (cachedSSLContext == null) { - cachedSSLContext = createSslContextFromBytes(proxyConfig.getCaData()); + cachedSSLContext = createSslContextFromBytes(caData); } } return cachedSSLContext; } // If CA file provided, read it (and re-read if it changes) - Path path = Paths.get(proxyConfig.getCaFile()); + Path path = Paths.get(caFile); if (!Files.exists(path)) { - throw new IOException("CA file not found: " + proxyConfig.getCaFile()); + throw LOGGER.logExceptionAsError(new RuntimeException("CA file not found: " + caFile)); } byte[] currentContent = Files.readAllBytes(path); synchronized (this) { if (currentContent.length == 0) { - throw new IOException("CA file " + proxyConfig.getCaFile() + " is empty"); + throw LOGGER.logExceptionAsError(new RuntimeException("CA file " + caFile + " is empty")); } if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) { @@ -168,7 +187,7 @@ private SSLContext getSSLContext() { return cachedSSLContext; } catch (Exception e) { - throw new RuntimeException("Failed to initialize SSLContext for proxy", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to initialize SSLContext for proxy", e)); } } @@ -188,11 +207,11 @@ private SSLContext createSslContextFromBytes(byte[] certificateData) { } if (certificates.isEmpty()) { - throw new RuntimeException("No valid certificates found"); + throw LOGGER.logExceptionAsError(new RuntimeException("No valid certificates found")); } return createSslContext(certificates); } catch (Exception e) { - throw new RuntimeException("Failed to create SSLContext from bytes", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext from bytes", e)); } } @@ -213,14 +232,14 @@ private SSLContext createSslContext(List certificates) { context.init(null, tmf.getTrustManagers(), null); return context; } catch (Exception e) { - throw new RuntimeException("Failed to create SSLContext", e); + throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext", e)); } } - private static HostnameVerifier sniAwareVerifier(String sniName, URL proxyUrl) { + private static HostnameVerifier sniAwareVerifier(String sniName, URL customProxyUrl) { return (urlHost, session) -> { String peerHost = session.getPeerHost(); - String expectedProxyHost = proxyUrl.getHost(); + String expectedProxyHost = customProxyUrl.getHost(); return peerHost.equalsIgnoreCase(expectedProxyHost) || (!CoreUtils.isNullOrEmpty(sniName) && peerHost.equalsIgnoreCase(sniName)); }; diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java index 973c04c48891..8eb14cea2764 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. package com.azure.identity.implementation.customtokenproxy; @@ -17,10 +17,13 @@ import com.azure.core.http.HttpHeaders; import com.azure.core.http.HttpRequest; import com.azure.core.http.HttpResponse; +import com.azure.core.util.logging.ClientLogger; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -public class CustomTokenProxyHttpResponse extends HttpResponse { +public final class CustomTokenProxyHttpResponse extends HttpResponse { + + private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpResponse.class); // private final HttpRequest request; private final int statusCode; @@ -52,7 +55,8 @@ private int extractStatusCode(HttpURLConnection connection) { try { return connection.getResponseCode(); } catch (IOException e) { - throw new RuntimeException("Failed to get status code from token proxy response", e); + throw LOGGER + .logExceptionAsError(new RuntimeException("Failed to get status code from token proxy response", e)); } } @@ -77,19 +81,26 @@ public Mono getBodyAsByteArray() { if (cachedResponseBodyBytes != null) { return cachedResponseBodyBytes; } - try (InputStream stream = getResponseStream()) { + + InputStream stream = null; + try { + stream = getResponseStream(); if (stream == null) { cachedResponseBodyBytes = new byte[0]; return cachedResponseBodyBytes; } ByteArrayOutputStream buffer = new ByteArrayOutputStream(); - byte[] tmp = new byte[4096]; int n; - while ((n = stream.read(tmp)) != -1) { - buffer.write(tmp, 0, n); + byte[] temp = new byte[4096]; + while ((n = stream.read(temp)) != -1) { + buffer.write(temp, 0, n); } cachedResponseBodyBytes = buffer.toByteArray(); return cachedResponseBodyBytes; + } finally { + if (stream != null) { + stream.close(); + } } }); } diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java index 098b25198fd3..171087bac6e9 100644 --- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java +++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. package com.azure.identity.implementation.customtokenproxy; diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java new file mode 100644 index 000000000000..1c3cd5581efc --- /dev/null +++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java @@ -0,0 +1,555 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.identity; + +import com.azure.core.credential.AccessToken; +import com.azure.core.credential.TokenRequestContext; +import com.azure.core.test.utils.TestConfigurationSource; +import com.azure.core.util.Configuration; +import com.azure.identity.util.TestUtils; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslProvider; +import io.netty.util.CharsetUtil; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import javax.net.ssl.SSLEngine; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.PrivateKey; +import java.security.Security; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.time.OffsetDateTime; +import java.util.Base64; +import java.util.Date; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * E2E test for WorkloadIdentityCredential with AKS Custom Token Proxy using Netty mock server. + * + */ +public class WorkloadIdentityCredentialIdentityBindingTest { + + private static final String AKS_SNI_NAME = "test-aks-proxy.ests.aks"; + + private static final String TEST_CLIENT_ID = "test-client-id"; + private static final String TEST_TENANT_ID = "test-tenant-id"; + private static final String MOCK_ACCESS_TOKEN = "mock_access_token_from_aks_proxy"; + + @TempDir + Path tempDir; + + private NioEventLoopGroup bossGroup; + private NioEventLoopGroup workerGroup; + private Channel serverChannel; + private String serverBaseUrl; + private Path tokenFilePath; + private Path caCertFilePath; + private String caCertPemData; + private AtomicInteger tokenRequestCount; + private AtomicInteger metadataRequestCount; + + @BeforeEach + public void setUp() throws Exception { + tokenRequestCount = new AtomicInteger(0); + metadataRequestCount = new AtomicInteger(0); + + // Generate key pair and certificate + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + KeyPair keyPair = keyPairGenerator.generateKeyPair(); + + // Generate self-signed certificate with both localhost and AKS SNI name + X509Certificate certificate = generateCertificateWithAksSni(keyPair, AKS_SNI_NAME); + + // Convert certificate to PEM format + caCertPemData = toPemFormat(certificate); + + // Write CA certificate to file + caCertFilePath = tempDir.resolve("ca-cert.pem"); + Files.write(caCertFilePath, caCertPemData.getBytes(StandardCharsets.UTF_8)); + + // Create mock federated token file + tokenFilePath = tempDir.resolve("token.jwt"); + String mockJwt = createMockFederatedToken(); + Files.write(tokenFilePath, mockJwt.getBytes(StandardCharsets.UTF_8)); + + // Start Netty HTTPS server + startNettyHttpsServer(keyPair.getPrivate(), certificate); + } + + @AfterEach + public void cleanup() { + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + } + if (bossGroup != null) { + bossGroup.shutdownGracefully(); + } + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + } + } + + @Test + public void testAksProxyWithCaFile() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithCaData() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_DATA", caCertPemData) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithCaFileButNoSni() { + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Access token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response"); + assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request"); + } + + @Test + public void testAksProxyWithInvalidTokenFile() { + Path nonExistentTokenFile = tempDir.resolve("non-existent-token.jwt"); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", nonExistentTokenFile.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(nonExistentTokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when token file doesn't exist"); + assertEquals(0, tokenRequestCount.get(), "No token request should reach the server with invalid token file"); + } + + @Test + public void testAksProxyWithInvalidCaCertificate() throws Exception { + String invalidCertData = "-----BEGIN CERTIFICATE-----\nINVALID_BASE64_DATA\n-----END CERTIFICATE-----"; + Path invalidCaCertFile = tempDir.resolve("invalid-ca.pem"); + Files.write(invalidCaCertFile, invalidCertData.getBytes(StandardCharsets.UTF_8)); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", invalidCaCertFile.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when CA certificate is invalid"); + assertEquals(0, tokenRequestCount.get(), "No token request should succeed with invalid CA certificate"); + } + + @Test + public void testAksProxyWithHttpScheme() { + String httpProxyUrl = serverBaseUrl.replace("https://", "http://"); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", httpProxyUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when proxy URL uses HTTP instead of HTTPS"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with HTTP proxy URL"); + } + + @Test + public void testAksProxyWithMalformedUrl() { + String malformedUrl = "not-a-valid-url-at-all"; + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", malformedUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + Exception exception = assertThrows(Exception.class, () -> { + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + credential.getTokenSync(request); + }); + + assertNotNull(exception, "Should throw exception when proxy URL is malformed"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with malformed proxy URL"); + } + + @Test + public void testAksProxyUnreachable() { + String unreachableProxyUrl = "https://localhost:19999"; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", unreachableProxyUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when proxy server is unreachable"); + assertEquals(0, tokenRequestCount.get(), "No token request should succeed when proxy is unreachable"); + } + + @Test + public void testAksProxyWithEmptyTokenFile() throws Exception { + Path emptyTokenFile = tempDir.resolve("empty-token.jwt"); + Files.write(emptyTokenFile, new byte[0]); + + Configuration configuration = TestUtils + .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", emptyTokenFile.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(emptyTokenFile.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request)); + + assertNotNull(exception, "Should throw exception when token file is empty"); + assertEquals(0, tokenRequestCount.get(), "No token request should be made with empty token file"); + } + + @Test + public void testAksProxyWithUrlEncodedCharactersInPath() throws Exception { + String encodedPath = "/api%2Fv1/token"; + String proxyUrlWithEncoding = serverBaseUrl + encodedPath; + + Configuration configuration = TestUtils.createTestConfiguration( + new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", proxyUrlWithEncoding) + .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString()) + .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME) + .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID) + .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID) + .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString())); + + WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID) + .clientId(TEST_CLIENT_ID) + .tokenFilePath(tokenFilePath.toString()) + .configuration(configuration) + .enableKubernetesTokenProxy() + .disableInstanceDiscovery() + .build(); + + TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default"); + + AccessToken token = credential.getTokenSync(request); + + assertNotNull(token, "Token should not be null"); + assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock token"); + assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired"); + assertEquals(1, tokenRequestCount.get(), "Token request should be made once"); + } + + private X509Certificate generateCertificateWithAksSni(KeyPair keyPair, String aksSniName) throws Exception { + if (Security.getProvider("BC") == null) { + Security.addProvider(new BouncyCastleProvider()); + } + + long now = System.currentTimeMillis(); + Date notBefore = new Date(now - 1000L * 60 * 60); + Date notAfter = new Date(now + 1000L * 60 * 60 * 24 * 365); + BigInteger serial = BigInteger.valueOf(now); + + X500Name subject = new X500Name("CN=AKS-Proxy-Test"); + SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded()); + + X509v3CertificateBuilder certBuilder + = new X509v3CertificateBuilder(subject, serial, notBefore, notAfter, subject, publicKeyInfo); + + GeneralName[] sans = new GeneralName[] { + new GeneralName(GeneralName.dNSName, "localhost"), + new GeneralName(GeneralName.dNSName, new org.bouncycastle.asn1.DERIA5String(aksSniName)) }; + GeneralNames subjectAltNames = new GeneralNames(sans); + certBuilder.addExtension(Extension.subjectAlternativeName, false, subjectAltNames); + + ContentSigner signer + = new JcaContentSignerBuilder("SHA256withRSA").setProvider("BC").build(keyPair.getPrivate()); + + X509CertificateHolder certHolder = certBuilder.build(signer); + + CertificateFactory certFactory = CertificateFactory.getInstance("X.509"); + try (InputStream in = new ByteArrayInputStream(certHolder.getEncoded())) { + return (X509Certificate) certFactory.generateCertificate(in); + } + } + + private void startNettyHttpsServer(PrivateKey privateKey, X509Certificate certificate) throws Exception { + SslContext sslContext + = SslContextBuilder.forServer(privateKey, certificate).sslProvider(SslProvider.OPENSSL).build(); + + bossGroup = new NioEventLoopGroup(1); + workerGroup = new NioEventLoopGroup(); + + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + + SslHandler sslHandler = sslContext.newHandler(ch.alloc()); + + sslHandler.handshakeFuture().addListener(future -> { + if (future.isSuccess()) { + SSLEngine engine = sslHandler.engine(); + System.out.println("OpenSSL handshake successful."); + } + }); + + pipeline.addLast(sslHandler); + pipeline.addLast(new HttpServerCodec()); + pipeline.addLast(new HttpObjectAggregator(65536)); + pipeline.addLast(new HttpRequestHandler()); + } + }); + + serverChannel = bootstrap.bind("localhost", 0).sync().channel(); + InetSocketAddress socketAddress = (InetSocketAddress) serverChannel.localAddress(); + int port = socketAddress.getPort(); + serverBaseUrl = "https://localhost:" + port; + + System.out.println("Netty HTTPS server (OpenSSL) started at: " + serverBaseUrl); + } + + private class HttpRequestHandler extends SimpleChannelInboundHandler { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) { + String uri = request.uri(); + String responseBody; + + if (uri.contains("/.well-known/openid-configuration")) { + metadataRequestCount.incrementAndGet(); + String base = serverBaseUrl + "/" + TEST_TENANT_ID; + responseBody = String + .format("{%n" + " \"token_endpoint\": \"%s/oauth2/v2.0/token\",%n" + " \"issuer\": \"%s/v2.0\",%n" + + " \"authorization_endpoint\": \"%s/oauth2/v2.0/authorize\"%n" + "}", base, base, base); + } else { + tokenRequestCount.incrementAndGet(); + responseBody = String.format( + "{\"token_type\":\"Bearer\",\"expires_in\":3600,\"ext_expires_in\":3600,\"access_token\":\"%s\"}", + MOCK_ACCESS_TOKEN); + } + + ByteBuf content = Unpooled.copiedBuffer(responseBody, CharsetUtil.UTF_8); + FullHttpResponse response + = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json"); + response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()); + + ctx.writeAndFlush(response); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + ctx.close(); + } + } + + private String toPemFormat(X509Certificate certificate) throws Exception { + Base64.Encoder encoder = Base64.getMimeEncoder(64, "\n".getBytes(StandardCharsets.UTF_8)); + byte[] encoded = encoder.encode(certificate.getEncoded()); + return "-----BEGIN CERTIFICATE-----\n" + new String(encoded, StandardCharsets.UTF_8) + + "\n-----END CERTIFICATE-----\n"; + } + + private String createMockFederatedToken() { + String header = Base64.getUrlEncoder() + .withoutPadding() + .encodeToString("{\"alg\":\"RS256\",\"typ\":\"JWT\"}".getBytes(StandardCharsets.UTF_8)); + + long exp = System.currentTimeMillis() / 1000 + 3600; + String payload = Base64.getUrlEncoder() + .withoutPadding() + .encodeToString(String.format( + "{\"aud\":\"api://AzureADTokenExchange\",\"exp\":%d,\"iss\":\"kubernetes.io/serviceaccount\",\"sub\":\"system:serviceaccount:default:workload-identity-sa\"}", + exp).getBytes(StandardCharsets.UTF_8)); + + String signature + = Base64.getUrlEncoder().withoutPadding().encodeToString("mock-signature".getBytes(StandardCharsets.UTF_8)); + + return header + "." + payload + "." + signature; + } +}