diff --git a/sdk/identity/azure-identity/pom.xml b/sdk/identity/azure-identity/pom.xml
index 82dd64b7a0f5..a3eb2cc64d4a 100644
--- a/sdk/identity/azure-identity/pom.xml
+++ b/sdk/identity/azure-identity/pom.xml
@@ -115,6 +115,13 @@
1.17.7
test
+
+
+ org.bouncycastle
+ bcpkix-lts8on
+ 2.73.8
+ test
+
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java
index a0c04ff2d952..40c3ac8ce4b4 100644
--- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredential.java
@@ -10,6 +10,9 @@
import com.azure.core.util.CoreUtils;
import com.azure.core.util.logging.ClientLogger;
import com.azure.identity.implementation.IdentityClientOptions;
+import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyConfiguration;
+import com.azure.identity.implementation.customtokenproxy.CustomTokenProxyHttpClient;
+import com.azure.identity.implementation.customtokenproxy.ProxyConfig;
import com.azure.identity.implementation.util.LoggingUtil;
import com.azure.identity.implementation.util.ValidationUtil;
import reactor.core.publisher.Mono;
@@ -89,6 +92,13 @@ public class WorkloadIdentityCredential implements TokenCredential {
ClientAssertionCredential tempClientAssertionCredential = null;
String tempClientId = null;
+ if (identityClientOptions.isKubernetesTokenProxyEnabled()) {
+ if (CustomTokenProxyConfiguration.isConfigured(configuration)) {
+ ProxyConfig proxyConfig = CustomTokenProxyConfiguration.parseAndValidate(configuration);
+ identityClientOptions.setHttpClient(new CustomTokenProxyHttpClient(proxyConfig));
+ }
+ }
+
if (!(CoreUtils.isNullOrEmpty(tenantIdInput)
|| CoreUtils.isNullOrEmpty(federatedTokenFilePathInput)
|| CoreUtils.isNullOrEmpty(clientIdInput)
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java
index 95f16fd9628d..31e62c6f4e14 100644
--- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/WorkloadIdentityCredentialBuilder.java
@@ -47,6 +47,7 @@
public class WorkloadIdentityCredentialBuilder extends AadCredentialBuilderBase {
private static final ClientLogger LOGGER = new ClientLogger(WorkloadIdentityCredentialBuilder.class);
private String tokenFilePath;
+ private boolean enableTokenProxy = false;
/**
* Creates an instance of a WorkloadIdentityCredentialBuilder.
@@ -66,6 +67,19 @@ public WorkloadIdentityCredentialBuilder tokenFilePath(String tokenFilePath) {
return this;
}
+ /**
+ * Enables the Kubernetes token proxy feature for AKS workload identity scenarios.
+ * When enabled, the credential will attempt to use a custom token proxy configured through
+ * environment variables (AZURE_KUBERNETES_TOKEN_PROXY, AZURE_KUBERNETES_CA_FILE,
+ * AZURE_KUBERNETES_CA_DATA, AZURE_KUBERNETES_SNI_NAME).
+ *
+ * @return An updated instance of this builder with Kubernetes token proxy enabled.
+ */
+ public WorkloadIdentityCredentialBuilder enableKubernetesTokenProxy() {
+ this.enableTokenProxy = true;
+ return this;
+ }
+
/**
* Creates new {@link WorkloadIdentityCredential} with the configured options set.
*
@@ -88,6 +102,8 @@ public WorkloadIdentityCredential build() {
ValidationUtil.validate(this.getClass().getSimpleName(), LOGGER, "Client ID", clientIdInput, "Tenant ID",
tenantIdInput, "Service Token File Path", federatedTokenFilePathInput);
+ identityClientOptions.setEnableKubernetesTokenProxy(this.enableTokenProxy);
+
return new WorkloadIdentityCredential(tenantIdInput, clientIdInput, federatedTokenFilePathInput,
identityClientOptions.clone());
}
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java
index d934c23c1d67..3c2d21d08c56 100644
--- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/IdentityClientOptions.java
@@ -71,6 +71,7 @@ public final class IdentityClientOptions implements Cloneable {
private List perRetryPolicies;
private boolean instanceDiscovery;
private String dacEnvConfiguredCredential;
+ private boolean enableKubernetesTokenProxy;
private Duration credentialProcessTimeout = Duration.ofSeconds(10);
@@ -833,6 +834,15 @@ public String getDACEnvConfiguredCredential() {
return dacEnvConfiguredCredential;
}
+ public boolean isKubernetesTokenProxyEnabled() {
+ return enableKubernetesTokenProxy;
+ }
+
+ public IdentityClientOptions setEnableKubernetesTokenProxy(boolean enableTokenProxy) {
+ this.enableKubernetesTokenProxy = enableTokenProxy;
+ return this;
+ }
+
public IdentityClientOptions clone() {
IdentityClientOptions clone
= new IdentityClientOptions().setAdditionallyAllowedTenants(this.additionallyAllowedTenants)
@@ -863,7 +873,8 @@ public IdentityClientOptions clone() {
.setPerRetryPolicies(this.perRetryPolicies)
.setBrowserCustomizationOptions(this.browserCustomizationOptions)
.setChained(this.isChained)
- .subscription(this.subscription);
+ .subscription(this.subscription)
+ .setEnableKubernetesTokenProxy(this.enableKubernetesTokenProxy);
if (isBrokerEnabled()) {
clone.setBrokerWindowHandle(this.brokerWindowHandle);
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java
new file mode 100644
index 000000000000..da7eddc1e38b
--- /dev/null
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyConfiguration.java
@@ -0,0 +1,112 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.identity.implementation.customtokenproxy;
+
+import com.azure.core.util.logging.ClientLogger;
+
+import java.net.URI;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.net.URISyntaxException;
+
+import com.azure.core.util.Configuration;
+import com.azure.core.util.CoreUtils;
+
+public final class CustomTokenProxyConfiguration {
+
+ private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyConfiguration.class);
+
+ public static final String AZURE_KUBERNETES_TOKEN_PROXY = "AZURE_KUBERNETES_TOKEN_PROXY";
+ public static final String AZURE_KUBERNETES_CA_FILE = "AZURE_KUBERNETES_CA_FILE";
+ public static final String AZURE_KUBERNETES_CA_DATA = "AZURE_KUBERNETES_CA_DATA";
+ public static final String AZURE_KUBERNETES_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME";
+
+ private CustomTokenProxyConfiguration() {
+ }
+
+ public static boolean isConfigured(Configuration configuration) {
+ String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY);
+ return !CoreUtils.isNullOrEmpty(tokenProxyUrl);
+ }
+
+ public static ProxyConfig parseAndValidate(Configuration configuration) {
+ String tokenProxyUrl = configuration.get(AZURE_KUBERNETES_TOKEN_PROXY);
+ String caFile = configuration.get(AZURE_KUBERNETES_CA_FILE);
+ String caData = configuration.get(AZURE_KUBERNETES_CA_DATA);
+ String sniName = configuration.get(AZURE_KUBERNETES_SNI_NAME);
+
+ if (CoreUtils.isNullOrEmpty(tokenProxyUrl)) {
+ if (!CoreUtils.isNullOrEmpty(sniName)
+ || !CoreUtils.isNullOrEmpty(caFile)
+ || !CoreUtils.isNullOrEmpty(caData)) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException(
+ "AZURE_KUBERNETES_TOKEN_PROXY is not set but other custom endpoint-related environment variables are present"));
+ }
+ return null;
+ }
+
+ if (!CoreUtils.isNullOrEmpty(caFile) && !CoreUtils.isNullOrEmpty(caData)) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException(
+ "Only one of AZURE_KUBERNETES_CA_FILE or AZURE_KUBERNETES_CA_DATA can be set."));
+ }
+
+ URL proxyUrl = validateProxyUrl(tokenProxyUrl);
+
+ byte[] caCertBytes = null;
+ if (!CoreUtils.isNullOrEmpty(caData)) {
+ try {
+ caCertBytes = caData.getBytes(StandardCharsets.UTF_8);
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException(
+ "Failed to decode CA certificate data from AZURE_KUBERNETES_CA_DATA", e));
+ }
+ }
+
+ ProxyConfig config = new ProxyConfig(proxyUrl, sniName, caFile, caCertBytes);
+ return config;
+ }
+
+ private static URL validateProxyUrl(String endpoint) {
+ if (CoreUtils.isNullOrEmpty(endpoint)) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException("Proxy endpoint cannot be null or empty"));
+ }
+
+ try {
+ URI tokenProxy = new URI(endpoint);
+
+ if (!"https".equals(tokenProxy.getScheme())) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException(
+ "Custom token endpoint must use https scheme, got: " + tokenProxy.getScheme()));
+ }
+
+ if (tokenProxy.getRawUserInfo() != null) {
+ throw LOGGER.logExceptionAsError(
+ new IllegalArgumentException("Custom token endpoint URL must not contain user info: " + endpoint));
+ }
+
+ if (tokenProxy.getRawQuery() != null) {
+ throw LOGGER.logExceptionAsError(
+ new IllegalArgumentException("Custom token endpoint URL must not contain a query: " + endpoint));
+ }
+
+ if (tokenProxy.getRawFragment() != null) {
+ throw LOGGER.logExceptionAsError(
+ new IllegalArgumentException("Custom token endpoint URL must not contain a fragment: " + endpoint));
+ }
+
+ if (tokenProxy.getRawPath() == null || tokenProxy.getRawPath().isEmpty()) {
+ tokenProxy = new URI(tokenProxy.getScheme(), null, tokenProxy.getHost(), tokenProxy.getPort(), "/",
+ null, null);
+ }
+
+ return tokenProxy.toURL();
+
+ } catch (URISyntaxException | IllegalArgumentException e) {
+ throw LOGGER.logExceptionAsError(new IllegalArgumentException("Failed to normalize proxy URL path", e));
+ } catch (Exception e) {
+ throw new RuntimeException("Unexpected error while validating proxy URL: " + endpoint, e);
+ }
+ }
+
+}
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java
new file mode 100644
index 000000000000..4cad1ba381be
--- /dev/null
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpClient.java
@@ -0,0 +1,248 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.identity.implementation.customtokenproxy;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.HttpURLConnection;
+import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URL;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.KeyStore;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import javax.net.ssl.HostnameVerifier;
+import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
+import javax.net.ssl.TrustManagerFactory;
+
+import com.azure.core.http.HttpClient;
+import com.azure.core.http.HttpRequest;
+import com.azure.core.http.HttpResponse;
+import com.azure.core.util.Context;
+import com.azure.core.util.CoreUtils;
+import com.azure.core.util.logging.ClientLogger;
+import com.azure.identity.implementation.util.IdentitySslUtil;
+
+import reactor.core.publisher.Mono;
+
+public class CustomTokenProxyHttpClient implements HttpClient {
+
+ private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpClient.class);
+
+ private final ProxyConfig proxyConfig;
+ private volatile SSLContext cachedSSLContext;
+ private volatile byte[] cachedFileContent;
+ private final URL proxyUrl;
+ private final String sniName;
+ private final byte[] caData;
+ private final String caFile;
+
+ public CustomTokenProxyHttpClient(ProxyConfig proxyConfig) {
+ this.proxyConfig = proxyConfig;
+ this.proxyUrl = proxyConfig.getTokenProxyUrl();
+ this.sniName = proxyConfig.getSniName();
+ this.caData = proxyConfig.getCaData();
+ this.caFile = proxyConfig.getCaFile();
+ }
+
+ @Override
+ public Mono send(HttpRequest request) {
+ return Mono.fromCallable(() -> sendSync(request, Context.NONE));
+ }
+
+ @Override
+ public HttpResponse sendSync(HttpRequest request, Context context) {
+ try {
+ HttpURLConnection connection = createConnection(request);
+ return new CustomTokenProxyHttpResponse(request, connection);
+ } catch (IOException e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create connection to token proxy", e));
+ }
+ }
+
+ private HttpURLConnection createConnection(HttpRequest request) throws IOException {
+ URL updatedUrl = rewriteTokenRequestForProxy(request.getUrl());
+ HttpsURLConnection connection = (HttpsURLConnection) updatedUrl.openConnection();
+ // If SNI explicitly provided
+ try {
+ SSLContext sslContext = getSSLContext();
+ SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
+ if (!CoreUtils.isNullOrEmpty(sniName)) {
+ sslSocketFactory = new IdentitySslUtil.SniSslSocketFactory(sslSocketFactory, sniName);
+ }
+ connection.setSSLSocketFactory(sslSocketFactory);
+ connection.setHostnameVerifier(sniAwareVerifier(sniName, proxyUrl));
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to set up SSL context for token proxy", e));
+ }
+
+ String method = request.getHttpMethod().toString();
+ connection.setRequestMethod(method);
+ connection.setInstanceFollowRedirects(false);
+ connection.setConnectTimeout(10_000);
+ connection.setReadTimeout(20_000);
+
+ boolean hasBody = request.getBodyAsBinaryData() != null && request.getBodyAsBinaryData().toBytes() != null;
+
+ if (hasBody && ("POST".equals(method) || "PUT".equals(method) || "PATCH".equals(method))) {
+ connection.setDoOutput(true);
+ } else {
+ connection.setDoOutput(false);
+ }
+
+ request.getHeaders().forEach(header -> {
+ connection.addRequestProperty(header.getName(), header.getValue());
+ });
+
+ if (request.getBodyAsBinaryData() != null) {
+ byte[] bytes = request.getBodyAsBinaryData().toBytes();
+ if (bytes != null && bytes.length > 0) {
+ try (OutputStream os = connection.getOutputStream()) {
+ os.write(bytes);
+ os.flush();
+ }
+ }
+ }
+ return connection;
+ }
+
+ private URL rewriteTokenRequestForProxy(URL originalUrl) throws MalformedURLException {
+ try {
+ String originalPath = originalUrl.getPath();
+ String originalQuery = originalUrl.getQuery();
+
+ String tokenProxyBase = proxyUrl.toString();
+ if (!tokenProxyBase.endsWith("/")) {
+ tokenProxyBase += "/";
+ }
+
+ URI combined = URI.create(tokenProxyBase)
+ .resolve(originalPath.startsWith("/") ? originalPath.substring(1) : originalPath);
+
+ String combinedStr = combined.toString();
+ if (originalQuery != null && !originalQuery.isEmpty()) {
+ combinedStr += "?" + originalQuery;
+ }
+
+ return new URL(combinedStr);
+
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to rewrite token request for proxy", e));
+ }
+ }
+
+ private SSLContext getSSLContext() {
+ try {
+ // If no CA override provided, use default
+ if (CoreUtils.isNullOrEmpty(caFile) && (caData == null || caData.length == 0)) {
+ synchronized (this) {
+ if (cachedSSLContext == null) {
+ cachedSSLContext = SSLContext.getDefault();
+ }
+ }
+ return cachedSSLContext;
+ }
+
+ // If CA data provided, use it
+ if (CoreUtils.isNullOrEmpty(caFile)) {
+ synchronized (this) {
+ if (cachedSSLContext == null) {
+ cachedSSLContext = createSslContextFromBytes(caData);
+ }
+ }
+ return cachedSSLContext;
+ }
+
+ // If CA file provided, read it (and re-read if it changes)
+ Path path = Paths.get(caFile);
+ if (!Files.exists(path)) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("CA file not found: " + caFile));
+ }
+
+ byte[] currentContent = Files.readAllBytes(path);
+
+ synchronized (this) {
+ if (currentContent.length == 0) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("CA file " + caFile + " is empty"));
+ }
+
+ if (cachedSSLContext == null || !Arrays.equals(currentContent, cachedFileContent)) {
+ cachedSSLContext = createSslContextFromBytes(currentContent);
+ cachedFileContent = currentContent;
+ }
+ }
+ return cachedSSLContext;
+
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to initialize SSLContext for proxy", e));
+ }
+ }
+
+ // Create SSLContext from byte array containing PEM certificate data
+ private SSLContext createSslContextFromBytes(byte[] certificateData) {
+ try (InputStream inputStream = new ByteArrayInputStream(certificateData)) {
+ CertificateFactory cf = CertificateFactory.getInstance("X.509");
+
+ List certificates = new ArrayList<>();
+ while (true) {
+ try {
+ X509Certificate cert = (X509Certificate) cf.generateCertificate(inputStream);
+ certificates.add(cert);
+ } catch (CertificateException e) {
+ break;
+ }
+ }
+
+ if (certificates.isEmpty()) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("No valid certificates found"));
+ }
+ return createSslContext(certificates);
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext from bytes", e));
+ }
+ }
+
+ // Create SSLContext from a single X509Certificate
+ private SSLContext createSslContext(List certificates) {
+ try {
+ KeyStore keystore = KeyStore.getInstance(KeyStore.getDefaultType());
+ keystore.load(null, null);
+ int index = 1;
+ for (X509Certificate caCert : certificates) {
+ keystore.setCertificateEntry("ca-cert-" + index, caCert);
+ index++;
+ }
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(keystore);
+
+ SSLContext context = SSLContext.getInstance("TLS");
+ context.init(null, tmf.getTrustManagers(), null);
+ return context;
+ } catch (Exception e) {
+ throw LOGGER.logExceptionAsError(new RuntimeException("Failed to create SSLContext", e));
+ }
+ }
+
+ private static HostnameVerifier sniAwareVerifier(String sniName, URL customProxyUrl) {
+ return (urlHost, session) -> {
+ String peerHost = session.getPeerHost();
+ String expectedProxyHost = customProxyUrl.getHost();
+ return peerHost.equalsIgnoreCase(expectedProxyHost)
+ || (!CoreUtils.isNullOrEmpty(sniName) && peerHost.equalsIgnoreCase(sniName));
+ };
+ }
+
+}
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java
new file mode 100644
index 000000000000..8eb14cea2764
--- /dev/null
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/CustomTokenProxyHttpResponse.java
@@ -0,0 +1,136 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.identity.implementation.customtokenproxy;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.HttpURLConnection;
+import java.nio.ByteBuffer;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+
+import com.azure.core.http.HttpHeaderName;
+import com.azure.core.http.HttpHeaders;
+import com.azure.core.http.HttpRequest;
+import com.azure.core.http.HttpResponse;
+import com.azure.core.util.logging.ClientLogger;
+import reactor.core.publisher.Flux;
+import reactor.core.publisher.Mono;
+
+public final class CustomTokenProxyHttpResponse extends HttpResponse {
+
+ private static final ClientLogger LOGGER = new ClientLogger(CustomTokenProxyHttpResponse.class);
+
+ // private final HttpRequest request;
+ private final int statusCode;
+ private final HttpHeaders headers;
+ private final HttpURLConnection connection;
+ private byte[] cachedResponseBodyBytes;
+
+ public CustomTokenProxyHttpResponse(HttpRequest request, HttpURLConnection connection) {
+ super(request);
+ this.connection = connection;
+ this.statusCode = extractStatusCode(connection);
+ this.headers = extractHeaders(connection);
+ }
+
+ private HttpHeaders extractHeaders(HttpURLConnection connection) {
+ HttpHeaders headers = new HttpHeaders();
+ for (Map.Entry> entry : connection.getHeaderFields().entrySet()) {
+ String headerName = entry.getKey();
+ if (headerName != null) {
+ for (String headerValue : entry.getValue()) {
+ headers.add(headerName, headerValue);
+ }
+ }
+ }
+ return headers;
+ }
+
+ private int extractStatusCode(HttpURLConnection connection) {
+ try {
+ return connection.getResponseCode();
+ } catch (IOException e) {
+ throw LOGGER
+ .logExceptionAsError(new RuntimeException("Failed to get status code from token proxy response", e));
+ }
+ }
+
+ @Override
+ public int getStatusCode() {
+ return statusCode;
+ }
+
+ @Override
+ public String getHeaderValue(String name) {
+ return headers.getValue(HttpHeaderName.fromString(name));
+ }
+
+ @Override
+ public HttpHeaders getHeaders() {
+ return headers;
+ }
+
+ @Override
+ public Mono getBodyAsByteArray() {
+ return Mono.fromCallable(() -> {
+ if (cachedResponseBodyBytes != null) {
+ return cachedResponseBodyBytes;
+ }
+
+ InputStream stream = null;
+ try {
+ stream = getResponseStream();
+ if (stream == null) {
+ cachedResponseBodyBytes = new byte[0];
+ return cachedResponseBodyBytes;
+ }
+ ByteArrayOutputStream buffer = new ByteArrayOutputStream();
+ int n;
+ byte[] temp = new byte[4096];
+ while ((n = stream.read(temp)) != -1) {
+ buffer.write(temp, 0, n);
+ }
+ cachedResponseBodyBytes = buffer.toByteArray();
+ return cachedResponseBodyBytes;
+ } finally {
+ if (stream != null) {
+ stream.close();
+ }
+ }
+ });
+ }
+
+ @Override
+ public Flux getBody() {
+ return getBodyAsByteArray().flatMapMany(bytes -> Flux.just(ByteBuffer.wrap(bytes)));
+ }
+
+ @Override
+ public Mono getBodyAsString() {
+ return getBodyAsString(StandardCharsets.UTF_8);
+ }
+
+ @Override
+ public Mono getBodyAsString(Charset charset) {
+ return getBodyAsByteArray().map(bytes -> new String(bytes, charset));
+ }
+
+ @Override
+ public void close() {
+ connection.disconnect();
+ }
+
+ private InputStream getResponseStream() throws IOException {
+ try {
+ return connection.getInputStream();
+ } catch (IOException e) {
+ return connection.getErrorStream();
+ }
+ }
+
+}
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java
new file mode 100644
index 000000000000..171087bac6e9
--- /dev/null
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/customtokenproxy/ProxyConfig.java
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.identity.implementation.customtokenproxy;
+
+import java.net.URL;
+
+public class ProxyConfig {
+ private final URL tokenProxyUrl;
+ private final String sniName;
+ private final String caFile;
+ private final byte[] caData;
+
+ public ProxyConfig(URL tokenProxyUrl, String sniName, String caFile, byte[] caData) {
+ this.tokenProxyUrl = tokenProxyUrl;
+ this.sniName = sniName;
+ this.caFile = caFile;
+ this.caData = caData;
+ }
+
+ public URL getTokenProxyUrl() {
+ return tokenProxyUrl;
+ }
+
+ public String getSniName() {
+ return sniName;
+ }
+
+ public String getCaFile() {
+ return caFile;
+ }
+
+ public byte[] getCaData() {
+ return caData;
+ }
+}
diff --git a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java
index ba932d5e5a55..5f6b7161cd69 100644
--- a/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java
+++ b/sdk/identity/azure-identity/src/main/java/com/azure/identity/implementation/util/IdentitySslUtil.java
@@ -3,15 +3,23 @@
package com.azure.identity.implementation.util;
+import com.azure.core.util.CoreUtils;
import com.azure.core.util.logging.ClientLogger;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
+import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
+import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
+
+import java.io.IOException;
+import java.net.Socket;
+import java.nio.charset.StandardCharsets;
import java.security.KeyManagementException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
@@ -19,6 +27,7 @@
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
+import java.util.Collections;
public final class IdentitySslUtil {
public static final HostnameVerifier ALL_HOSTS_ACCEPT_HOSTNAME_VERIFIER;
@@ -125,4 +134,76 @@ private static String extractCertificateThumbprint(Certificate certificate, Clie
throw logger.logExceptionAsError(new RuntimeException(e));
}
}
+
+ public static final class SniSslSocketFactory extends SSLSocketFactory {
+ private final SSLSocketFactory sslSocketFactory;
+ private final String sniName;
+
+ public SniSslSocketFactory(SSLSocketFactory sslSocketFactory, String sniName) {
+ this.sslSocketFactory = sslSocketFactory;
+ this.sniName = sniName;
+ }
+
+ @Override
+ public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
+ Socket sslSocket = (SSLSocket) sslSocketFactory.createSocket(s, host, port, autoClose);
+ configureSni(sslSocket);
+ return sslSocket;
+ }
+
+ @Override
+ public Socket createSocket(String host, int port) throws IOException {
+ Socket socket = sslSocketFactory.createSocket(host, port);
+ configureSni(socket);
+ return socket;
+ }
+
+ @Override
+ public Socket createSocket(String host, int port, java.net.InetAddress localAddress, int localPort)
+ throws IOException {
+ Socket socket = sslSocketFactory.createSocket(host, port, localAddress, localPort);
+ configureSni(socket);
+ return socket;
+ }
+
+ @Override
+ public Socket createSocket(java.net.InetAddress host, int port) throws IOException {
+ Socket socket = sslSocketFactory.createSocket(host, port);
+ configureSni(socket);
+ return socket;
+ }
+
+ @Override
+ public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress,
+ int localPort) throws IOException {
+ Socket socket = sslSocketFactory.createSocket(address, port, localAddress, localPort);
+ configureSni(socket);
+ return socket;
+ }
+
+ @Override
+ public String[] getDefaultCipherSuites() {
+ return sslSocketFactory.getDefaultCipherSuites();
+ }
+
+ @Override
+ public String[] getSupportedCipherSuites() {
+ return sslSocketFactory.getSupportedCipherSuites();
+ }
+
+ private void configureSni(Socket socket) {
+ if (socket instanceof SSLSocket && !CoreUtils.isNullOrEmpty(sniName)) {
+ SSLSocket sslSocket = (SSLSocket) socket;
+ SSLParameters sslParameters = sslSocket.getSSLParameters();
+ sslParameters.setServerNames(Collections.singletonList(new RawSniServerName(sniName)));
+ sslSocket.setSSLParameters(sslParameters);
+ }
+ }
+ }
+
+ public static final class RawSniServerName extends SNIServerName {
+ public RawSniServerName(String sniHost) {
+ super(0, sniHost.getBytes(StandardCharsets.UTF_8));
+ }
+ }
}
diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java
new file mode 100644
index 000000000000..1c3cd5581efc
--- /dev/null
+++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialIdentityBindingTest.java
@@ -0,0 +1,555 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.identity;
+
+import com.azure.core.credential.AccessToken;
+import com.azure.core.credential.TokenRequestContext;
+import com.azure.core.test.utils.TestConfigurationSource;
+import com.azure.core.util.Configuration;
+import com.azure.identity.util.TestUtils;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.http.DefaultFullHttpResponse;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.FullHttpResponse;
+import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpObjectAggregator;
+import io.netty.handler.codec.http.HttpResponseStatus;
+import io.netty.handler.codec.http.HttpServerCodec;
+import io.netty.handler.codec.http.HttpVersion;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.ssl.SslProvider;
+import io.netty.util.CharsetUtil;
+import org.bouncycastle.asn1.x500.X500Name;
+import org.bouncycastle.asn1.x509.Extension;
+import org.bouncycastle.asn1.x509.GeneralName;
+import org.bouncycastle.asn1.x509.GeneralNames;
+import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
+import org.bouncycastle.cert.X509CertificateHolder;
+import org.bouncycastle.cert.X509v3CertificateBuilder;
+import org.bouncycastle.jce.provider.BouncyCastleProvider;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import javax.net.ssl.SSLEngine;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.math.BigInteger;
+import java.net.InetSocketAddress;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.PrivateKey;
+import java.security.Security;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.time.OffsetDateTime;
+import java.util.Base64;
+import java.util.Date;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/**
+ * E2E test for WorkloadIdentityCredential with AKS Custom Token Proxy using Netty mock server.
+ *
+ */
+public class WorkloadIdentityCredentialIdentityBindingTest {
+
+ private static final String AKS_SNI_NAME = "test-aks-proxy.ests.aks";
+
+ private static final String TEST_CLIENT_ID = "test-client-id";
+ private static final String TEST_TENANT_ID = "test-tenant-id";
+ private static final String MOCK_ACCESS_TOKEN = "mock_access_token_from_aks_proxy";
+
+ @TempDir
+ Path tempDir;
+
+ private NioEventLoopGroup bossGroup;
+ private NioEventLoopGroup workerGroup;
+ private Channel serverChannel;
+ private String serverBaseUrl;
+ private Path tokenFilePath;
+ private Path caCertFilePath;
+ private String caCertPemData;
+ private AtomicInteger tokenRequestCount;
+ private AtomicInteger metadataRequestCount;
+
+ @BeforeEach
+ public void setUp() throws Exception {
+ tokenRequestCount = new AtomicInteger(0);
+ metadataRequestCount = new AtomicInteger(0);
+
+ // Generate key pair and certificate
+ KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
+ keyPairGenerator.initialize(2048);
+ KeyPair keyPair = keyPairGenerator.generateKeyPair();
+
+ // Generate self-signed certificate with both localhost and AKS SNI name
+ X509Certificate certificate = generateCertificateWithAksSni(keyPair, AKS_SNI_NAME);
+
+ // Convert certificate to PEM format
+ caCertPemData = toPemFormat(certificate);
+
+ // Write CA certificate to file
+ caCertFilePath = tempDir.resolve("ca-cert.pem");
+ Files.write(caCertFilePath, caCertPemData.getBytes(StandardCharsets.UTF_8));
+
+ // Create mock federated token file
+ tokenFilePath = tempDir.resolve("token.jwt");
+ String mockJwt = createMockFederatedToken();
+ Files.write(tokenFilePath, mockJwt.getBytes(StandardCharsets.UTF_8));
+
+ // Start Netty HTTPS server
+ startNettyHttpsServer(keyPair.getPrivate(), certificate);
+ }
+
+ @AfterEach
+ public void cleanup() {
+ if (serverChannel != null) {
+ serverChannel.close().syncUninterruptibly();
+ }
+ if (bossGroup != null) {
+ bossGroup.shutdownGracefully();
+ }
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ }
+ }
+
+ @Test
+ public void testAksProxyWithCaFile() {
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ AccessToken token = credential.getTokenSync(request);
+
+ assertNotNull(token, "Access token should not be null");
+ assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response");
+ assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired");
+ assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request");
+ }
+
+ @Test
+ public void testAksProxyWithCaData() {
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_DATA", caCertPemData)
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ AccessToken token = credential.getTokenSync(request);
+
+ assertNotNull(token, "Access token should not be null");
+ assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response");
+ assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request");
+ }
+
+ @Test
+ public void testAksProxyWithCaFileButNoSni() {
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ AccessToken token = credential.getTokenSync(request);
+
+ assertNotNull(token, "Access token should not be null");
+ assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock response");
+ assertEquals(1, tokenRequestCount.get(), "Server should have received exactly one token request");
+ }
+
+ @Test
+ public void testAksProxyWithInvalidTokenFile() {
+ Path nonExistentTokenFile = tempDir.resolve("non-existent-token.jwt");
+
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", nonExistentTokenFile.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(nonExistentTokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request));
+
+ assertNotNull(exception, "Should throw exception when token file doesn't exist");
+ assertEquals(0, tokenRequestCount.get(), "No token request should reach the server with invalid token file");
+ }
+
+ @Test
+ public void testAksProxyWithInvalidCaCertificate() throws Exception {
+ String invalidCertData = "-----BEGIN CERTIFICATE-----\nINVALID_BASE64_DATA\n-----END CERTIFICATE-----";
+ Path invalidCaCertFile = tempDir.resolve("invalid-ca.pem");
+ Files.write(invalidCaCertFile, invalidCertData.getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", invalidCaCertFile.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ Exception exception = assertThrows(Exception.class, () -> {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ credential.getTokenSync(request);
+ });
+
+ assertNotNull(exception, "Should throw exception when CA certificate is invalid");
+ assertEquals(0, tokenRequestCount.get(), "No token request should succeed with invalid CA certificate");
+ }
+
+ @Test
+ public void testAksProxyWithHttpScheme() {
+ String httpProxyUrl = serverBaseUrl.replace("https://", "http://");
+
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", httpProxyUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ Exception exception = assertThrows(Exception.class, () -> {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ credential.getTokenSync(request);
+ });
+
+ assertNotNull(exception, "Should throw exception when proxy URL uses HTTP instead of HTTPS");
+ assertEquals(0, tokenRequestCount.get(), "No token request should be made with HTTP proxy URL");
+ }
+
+ @Test
+ public void testAksProxyWithMalformedUrl() {
+ String malformedUrl = "not-a-valid-url-at-all";
+
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", malformedUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ Exception exception = assertThrows(Exception.class, () -> {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ credential.getTokenSync(request);
+ });
+
+ assertNotNull(exception, "Should throw exception when proxy URL is malformed");
+ assertEquals(0, tokenRequestCount.get(), "No token request should be made with malformed proxy URL");
+ }
+
+ @Test
+ public void testAksProxyUnreachable() {
+ String unreachableProxyUrl = "https://localhost:19999";
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", unreachableProxyUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request));
+
+ assertNotNull(exception, "Should throw exception when proxy server is unreachable");
+ assertEquals(0, tokenRequestCount.get(), "No token request should succeed when proxy is unreachable");
+ }
+
+ @Test
+ public void testAksProxyWithEmptyTokenFile() throws Exception {
+ Path emptyTokenFile = tempDir.resolve("empty-token.jwt");
+ Files.write(emptyTokenFile, new byte[0]);
+
+ Configuration configuration = TestUtils
+ .createTestConfiguration(new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", serverBaseUrl)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", emptyTokenFile.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(emptyTokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ Exception exception = assertThrows(Exception.class, () -> credential.getTokenSync(request));
+
+ assertNotNull(exception, "Should throw exception when token file is empty");
+ assertEquals(0, tokenRequestCount.get(), "No token request should be made with empty token file");
+ }
+
+ @Test
+ public void testAksProxyWithUrlEncodedCharactersInPath() throws Exception {
+ String encodedPath = "/api%2Fv1/token";
+ String proxyUrlWithEncoding = serverBaseUrl + encodedPath;
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put("AZURE_KUBERNETES_TOKEN_PROXY", proxyUrlWithEncoding)
+ .put("AZURE_KUBERNETES_CA_FILE", caCertFilePath.toString())
+ .put("AZURE_KUBERNETES_SNI_NAME", AKS_SNI_NAME)
+ .put(Configuration.PROPERTY_AZURE_CLIENT_ID, TEST_CLIENT_ID)
+ .put(Configuration.PROPERTY_AZURE_TENANT_ID, TEST_TENANT_ID)
+ .put("AZURE_FEDERATED_TOKEN_FILE", tokenFilePath.toString()));
+
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId(TEST_TENANT_ID)
+ .clientId(TEST_CLIENT_ID)
+ .tokenFilePath(tokenFilePath.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .disableInstanceDiscovery()
+ .build();
+
+ TokenRequestContext request = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+
+ AccessToken token = credential.getTokenSync(request);
+
+ assertNotNull(token, "Token should not be null");
+ assertEquals(MOCK_ACCESS_TOKEN, token.getToken(), "Token value should match mock token");
+ assertTrue(token.getExpiresAt().isAfter(OffsetDateTime.now()), "Token should not be expired");
+ assertEquals(1, tokenRequestCount.get(), "Token request should be made once");
+ }
+
+ private X509Certificate generateCertificateWithAksSni(KeyPair keyPair, String aksSniName) throws Exception {
+ if (Security.getProvider("BC") == null) {
+ Security.addProvider(new BouncyCastleProvider());
+ }
+
+ long now = System.currentTimeMillis();
+ Date notBefore = new Date(now - 1000L * 60 * 60);
+ Date notAfter = new Date(now + 1000L * 60 * 60 * 24 * 365);
+ BigInteger serial = BigInteger.valueOf(now);
+
+ X500Name subject = new X500Name("CN=AKS-Proxy-Test");
+ SubjectPublicKeyInfo publicKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded());
+
+ X509v3CertificateBuilder certBuilder
+ = new X509v3CertificateBuilder(subject, serial, notBefore, notAfter, subject, publicKeyInfo);
+
+ GeneralName[] sans = new GeneralName[] {
+ new GeneralName(GeneralName.dNSName, "localhost"),
+ new GeneralName(GeneralName.dNSName, new org.bouncycastle.asn1.DERIA5String(aksSniName)) };
+ GeneralNames subjectAltNames = new GeneralNames(sans);
+ certBuilder.addExtension(Extension.subjectAlternativeName, false, subjectAltNames);
+
+ ContentSigner signer
+ = new JcaContentSignerBuilder("SHA256withRSA").setProvider("BC").build(keyPair.getPrivate());
+
+ X509CertificateHolder certHolder = certBuilder.build(signer);
+
+ CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
+ try (InputStream in = new ByteArrayInputStream(certHolder.getEncoded())) {
+ return (X509Certificate) certFactory.generateCertificate(in);
+ }
+ }
+
+ private void startNettyHttpsServer(PrivateKey privateKey, X509Certificate certificate) throws Exception {
+ SslContext sslContext
+ = SslContextBuilder.forServer(privateKey, certificate).sslProvider(SslProvider.OPENSSL).build();
+
+ bossGroup = new NioEventLoopGroup(1);
+ workerGroup = new NioEventLoopGroup();
+
+ ServerBootstrap bootstrap = new ServerBootstrap();
+ bootstrap.group(bossGroup, workerGroup)
+ .channel(NioServerSocketChannel.class)
+ .childHandler(new ChannelInitializer() {
+ @Override
+ protected void initChannel(SocketChannel ch) {
+ ChannelPipeline pipeline = ch.pipeline();
+
+ SslHandler sslHandler = sslContext.newHandler(ch.alloc());
+
+ sslHandler.handshakeFuture().addListener(future -> {
+ if (future.isSuccess()) {
+ SSLEngine engine = sslHandler.engine();
+ System.out.println("OpenSSL handshake successful.");
+ }
+ });
+
+ pipeline.addLast(sslHandler);
+ pipeline.addLast(new HttpServerCodec());
+ pipeline.addLast(new HttpObjectAggregator(65536));
+ pipeline.addLast(new HttpRequestHandler());
+ }
+ });
+
+ serverChannel = bootstrap.bind("localhost", 0).sync().channel();
+ InetSocketAddress socketAddress = (InetSocketAddress) serverChannel.localAddress();
+ int port = socketAddress.getPort();
+ serverBaseUrl = "https://localhost:" + port;
+
+ System.out.println("Netty HTTPS server (OpenSSL) started at: " + serverBaseUrl);
+ }
+
+ private class HttpRequestHandler extends SimpleChannelInboundHandler {
+ @Override
+ protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
+ String uri = request.uri();
+ String responseBody;
+
+ if (uri.contains("/.well-known/openid-configuration")) {
+ metadataRequestCount.incrementAndGet();
+ String base = serverBaseUrl + "/" + TEST_TENANT_ID;
+ responseBody = String
+ .format("{%n" + " \"token_endpoint\": \"%s/oauth2/v2.0/token\",%n" + " \"issuer\": \"%s/v2.0\",%n"
+ + " \"authorization_endpoint\": \"%s/oauth2/v2.0/authorize\"%n" + "}", base, base, base);
+ } else {
+ tokenRequestCount.incrementAndGet();
+ responseBody = String.format(
+ "{\"token_type\":\"Bearer\",\"expires_in\":3600,\"ext_expires_in\":3600,\"access_token\":\"%s\"}",
+ MOCK_ACCESS_TOKEN);
+ }
+
+ ByteBuf content = Unpooled.copiedBuffer(responseBody, CharsetUtil.UTF_8);
+ FullHttpResponse response
+ = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, content);
+ response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json");
+ response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes());
+
+ ctx.writeAndFlush(response);
+ }
+
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ cause.printStackTrace();
+ ctx.close();
+ }
+ }
+
+ private String toPemFormat(X509Certificate certificate) throws Exception {
+ Base64.Encoder encoder = Base64.getMimeEncoder(64, "\n".getBytes(StandardCharsets.UTF_8));
+ byte[] encoded = encoder.encode(certificate.getEncoded());
+ return "-----BEGIN CERTIFICATE-----\n" + new String(encoded, StandardCharsets.UTF_8)
+ + "\n-----END CERTIFICATE-----\n";
+ }
+
+ private String createMockFederatedToken() {
+ String header = Base64.getUrlEncoder()
+ .withoutPadding()
+ .encodeToString("{\"alg\":\"RS256\",\"typ\":\"JWT\"}".getBytes(StandardCharsets.UTF_8));
+
+ long exp = System.currentTimeMillis() / 1000 + 3600;
+ String payload = Base64.getUrlEncoder()
+ .withoutPadding()
+ .encodeToString(String.format(
+ "{\"aud\":\"api://AzureADTokenExchange\",\"exp\":%d,\"iss\":\"kubernetes.io/serviceaccount\",\"sub\":\"system:serviceaccount:default:workload-identity-sa\"}",
+ exp).getBytes(StandardCharsets.UTF_8));
+
+ String signature
+ = Base64.getUrlEncoder().withoutPadding().encodeToString("mock-signature".getBytes(StandardCharsets.UTF_8));
+
+ return header + "." + payload + "." + signature;
+ }
+}
diff --git a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java
index d2404d9b302b..e52c3ef8735f 100644
--- a/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java
+++ b/sdk/identity/azure-identity/src/test/java/com/azure/identity/WorkloadIdentityCredentialTest.java
@@ -35,6 +35,10 @@
public class WorkloadIdentityCredentialTest {
private static final String CLIENT_ID = UUID.randomUUID().toString();
+ private static final String ENV_PROXY_URL = "AZURE_KUBERNETES_TOKEN_PROXY";
+ private static final String ENV_CA_FILE = "AZURE_KUBERNETES_CA_FILE";
+ private static final String ENV_CA_DATA = "AZURE_KUBERNETES_CA_DATA";
+ private static final String ENV_SNI_NAME = "AZURE_KUBERNETES_SNI_NAME";
@Test
public void testWorkloadIdentityFlow(@TempDir Path tempDir) throws IOException {
@@ -192,4 +196,342 @@ public void testFileReadingError(@TempDir Path tempDir) {
assertTrue(error.getCause() instanceof IOException); // Original IOException from Files.readAllBytes
}).verify();
}
+
+ @Test
+ public void testProxyEnabledWithProxyUrlGetsToken(@TempDir Path tempDir) throws IOException {
+ // setup
+ String endpoint = "https://localhost";
+ String token1 = "token1";
+ String proxyUrl = "https://token-proxy.example.com";
+
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, proxyUrl));
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ try (MockedConstruction identityClientMock
+ = mockConstruction(IdentityClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty());
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ StepVerifier.create(credential.getToken(request1))
+ .expectNextMatches(token -> token1.equals(token.getToken())
+ && expiresAt.getSecond() == token.getExpiresAt().getSecond())
+ .verifyComplete();
+
+ assertNotNull(identityClientMock);
+ }
+ }
+
+ @Test
+ public void testProxyEnabledWithoutProxyUrlGetsToken(@TempDir Path tempDir) throws IOException {
+ // setup
+ String endpoint = "https://localhost";
+ String token1 = "token1";
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint));
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ try (MockedConstruction identityClientMock
+ = mockConstruction(IdentityClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty());
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ StepVerifier.create(credential.getToken(request1))
+ .expectNextMatches(token -> token1.equals(token.getToken())
+ && expiresAt.getSecond() == token.getExpiresAt().getSecond())
+ .verifyComplete();
+
+ assertNotNull(identityClientMock);
+ }
+ }
+
+ @Test
+ public void testProxyEnabledInvalidProxyUrlSchemeFailure(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://localhost";
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "http://not-https.example.com"));
+
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+ });
+ }
+
+ @Test
+ public void testProxyUrlWithQueryFailure(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://proxy.example.com?x=y"));
+
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+ });
+ }
+
+ @Test
+ public void testProxyUrlWithFragmentFailure(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://proxy.example.com#frag"));
+
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+ });
+ }
+
+ @Test
+ public void testProxyUrlWithUserInfoFailure(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://user:pass@proxy.example.com"));
+
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+ });
+ }
+
+ @Test
+ public void testCaFileAndCaDataPresentFailure(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ Path caFile = tempDir.resolve("ca.crt");
+ Files.write(caFile,
+ "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8));
+
+ String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----";
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://proxy.example.com")
+ .put(ENV_CA_FILE, caFile.toString())
+ .put(ENV_CA_DATA, caData));
+
+ Assertions.assertThrows(IllegalArgumentException.class, () -> {
+ new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+ });
+ }
+
+ @Test
+ public void testProxyEnabledWithProxyUrlGetsTokenSync(@TempDir Path tempDir) throws IOException {
+ // setup
+ String endpoint = "https://localhost";
+ String token1 = "token1";
+ String proxyUrl = "https://token-proxy.example.com";
+
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, proxyUrl));
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ try (MockedConstruction identitySyncClientMock
+ = mockConstruction(IdentitySyncClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any()))
+ .thenThrow(new IllegalStateException("Test"));
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessTokenSync(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ AccessToken token = credential.getTokenSync(request1);
+
+ assertTrue(token1.equals(token.getToken()));
+ assertTrue(expiresAt.getSecond() == token.getExpiresAt().getSecond());
+ assertNotNull(identitySyncClientMock);
+ }
+ }
+
+ @Test
+ public void testProxyUrlWithCaDataAcquiresToken(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ String token1 = "token-ca-data";
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+
+ String caData = "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----";
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://token-proxy.example.com")
+ .put(ENV_CA_DATA, caData));
+
+ try (MockedConstruction mocked
+ = mockConstruction(IdentityClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty());
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ StepVerifier.create(cred.getToken(request1))
+ .expectNextMatches(token -> token1.equals(token.getToken())
+ && token.getExpiresAt().getSecond() == expiresAt.getSecond())
+ .verifyComplete();
+
+ assertNotNull(mocked);
+ }
+ }
+
+ @Test
+ public void testProxyUrlWithCaFileGetsToken(@TempDir Path tempDir) throws IOException {
+ String endpoint = "https://login.microsoftonline.com";
+ String token1 = "tok-ca-file";
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+ Path caFile = tempDir.resolve("ca.crt");
+ Files.write(caFile,
+ "-----BEGIN CERTIFICATE-----\nMIIB...==\n-----END CERTIFICATE-----\n".getBytes(StandardCharsets.UTF_8));
+
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, "https://token-proxy.example.com")
+ .put(ENV_CA_FILE, caFile.toString()));
+
+ try (MockedConstruction mocked
+ = mockConstruction(IdentityClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty());
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential cred = new WorkloadIdentityCredentialBuilder().tenantId("tenant")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ StepVerifier.create(cred.getToken(request1))
+ .expectNextMatches(token -> token1.equals(token.getToken())
+ && token.getExpiresAt().getSecond() == expiresAt.getSecond())
+ .verifyComplete();
+
+ assertNotNull(mocked);
+ }
+ }
+
+ @Test
+ public void testProxyEnabledWithSniNameGetsToken(@TempDir Path tempDir) throws IOException {
+ // setup
+ String endpoint = "https://localhost";
+ String token1 = "token1";
+ String proxyUrl = "https://token-proxy.example.com";
+ String sniName = "615f3b8ad7eb011a09ed3b762e404de43ebc7ade0802a34c9fd322b688c3a655.ests.aks";
+
+ TokenRequestContext request1 = new TokenRequestContext().addScopes("https://management.azure.com/.default");
+ OffsetDateTime expiresAt = OffsetDateTime.now(ZoneOffset.UTC).plusHours(1);
+ Configuration configuration = TestUtils.createTestConfiguration(
+ new TestConfigurationSource().put(Configuration.PROPERTY_AZURE_AUTHORITY_HOST, endpoint)
+ .put(ENV_PROXY_URL, proxyUrl)
+ .put(ENV_SNI_NAME, sniName));
+
+ Path tokenFile = tempDir.resolve("token.txt");
+ Files.write(tokenFile, "dummy-token".getBytes(StandardCharsets.UTF_8));
+
+ try (MockedConstruction identityClientMock
+ = mockConstruction(IdentityClient.class, (identityClient, context) -> {
+ when(identityClient.authenticateWithConfidentialClientCache(any())).thenReturn(Mono.empty());
+ when(identityClient.authenticateWithConfidentialClient(any(TokenRequestContext.class)))
+ .thenReturn(TestUtils.getMockAccessToken(token1, expiresAt));
+ })) {
+ WorkloadIdentityCredential credential = new WorkloadIdentityCredentialBuilder().tenantId("dummy-tenantid")
+ .clientId(CLIENT_ID)
+ .tokenFilePath(tokenFile.toString())
+ .configuration(configuration)
+ .enableKubernetesTokenProxy()
+ .build();
+
+ StepVerifier.create(credential.getToken(request1))
+ .expectNextMatches(token -> token1.equals(token.getToken())
+ && expiresAt.getSecond() == token.getExpiresAt().getSecond())
+ .verifyComplete();
+
+ assertNotNull(identityClientMock);
+ }
+ }
+
}