Skip to content

Commit fb990f2

Browse files
authored
Fix for first run cache miss for interactive credentials (Azure#38449)
1 parent 45a08cf commit fb990f2

File tree

5 files changed

+34
-62
lines changed

5 files changed

+34
-62
lines changed

sdk/identity/Azure.Identity/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Bugs Fixed
66

77
- `ManagedIdentityCredential` will fall through to the next credential in the chain in the case that Docker Desktop returns a 403 response when attempting to access the IMDS endpoint. [#38218](https://github.com/Azure/azure-sdk-for-net/issues/38218)
8+
- Fixed an issue where interactive credentials would still prompt on the first GetToken request even when the cache is populated and an AuthenticationRecord is provided. [#38431](https://github.com/Azure/azure-sdk-for-net/issues/38431)
89

910
## 1.10.0 (2023-08-14)
1011

sdk/identity/Azure.Identity/src/Credentials/AuthorizationCodeCredential.cs

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ public class AuthorizationCodeCredential : TokenCredential
2222
private readonly string _clientId;
2323
private readonly CredentialPipeline _pipeline;
2424
private AuthenticationRecord _record;
25-
private bool _isCaeEnabledRequestCached = false;
26-
private bool _isCaeDisabledRequestCached = false;
2725
internal MsalConfidentialClient Client { get; }
2826
private readonly string _redirectUri;
2927
private readonly string _tenantId;
@@ -138,33 +136,15 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
138136
{
139137
using CredentialDiagnosticScope scope = _pipeline.StartGetTokenScope($"{nameof(AuthorizationCodeCredential)}.{nameof(GetToken)}", requestContext);
140138

139+
AccessToken token;
140+
string tenantId = null;
141141
try
142142
{
143-
AccessToken token;
144-
var tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);
145-
var isCachePopulated = _record switch
146-
{
147-
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
148-
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
149-
_ => false
150-
};
143+
tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);
151144

152-
if (!isCachePopulated)
145+
if (_record is null)
153146
{
154-
AuthenticationResult result = await Client
155-
.AcquireTokenByAuthorizationCodeAsync(requestContext.Scopes, _authCode, tenantId, _redirectUri, requestContext.IsCaeEnabled, async, cancellationToken)
156-
.ConfigureAwait(false);
157-
_record = new AuthenticationRecord(result, _clientId);
158-
if (requestContext.IsCaeEnabled)
159-
{
160-
_isCaeEnabledRequestCached = true;
161-
}
162-
else
163-
{
164-
_isCaeDisabledRequestCached = true;
165-
}
166-
167-
token = new AccessToken(result.AccessToken, result.ExpiresOn);
147+
token = await AcquireTokenWithCode(async, requestContext, token, tenantId, cancellationToken).ConfigureAwait(false);
168148
}
169149
else
170150
{
@@ -176,10 +156,35 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
176156

177157
return scope.Succeeded(token);
178158
}
159+
catch (MsalUiRequiredException)
160+
{
161+
// This occurs when we have an auth record but the cae or ncae cache entry is missing
162+
// fall through to the acquire call below
163+
}
179164
catch (Exception e)
180165
{
181166
throw scope.FailWrapAndThrow(e);
182167
}
168+
169+
try
170+
{
171+
token = await AcquireTokenWithCode(async, requestContext, token, tenantId, cancellationToken).ConfigureAwait(false);
172+
return scope.Succeeded(token);
173+
}
174+
catch (Exception e)
175+
{
176+
throw scope.FailWrapAndThrow(e);
177+
}
178+
}
179+
180+
private async Task<AccessToken> AcquireTokenWithCode(bool async, TokenRequestContext requestContext, AccessToken token, string tenantId, CancellationToken cancellationToken)
181+
{
182+
AuthenticationResult result = await Client
183+
.AcquireTokenByAuthorizationCodeAsync(requestContext.Scopes, _authCode, tenantId, _redirectUri, requestContext.IsCaeEnabled, async, cancellationToken)
184+
.ConfigureAwait(false);
185+
_record = new AuthenticationRecord(result, _clientId);
186+
token = new AccessToken(result.AccessToken, result.ExpiresOn);
187+
return token;
183188
}
184189
}
185190
}

sdk/identity/Azure.Identity/src/Credentials/DeviceCodeCredential.cs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ public class DeviceCodeCredential : TokenCredential
2323
internal string ClientId { get; }
2424
internal bool DisableAutomaticAuthentication { get; }
2525
internal AuthenticationRecord Record { get; private set; }
26-
private bool _isCaeEnabledRequestCached = false;
27-
private bool _isCaeDisabledRequestCached = false;
2826
internal Func<DeviceCodeInfo, CancellationToken, Task> DeviceCodeCallback { get; }
2927
internal CredentialPipeline Pipeline { get; }
3028
internal string DefaultScope { get; }
@@ -211,16 +209,9 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
211209
try
212210
{
213211
Exception inner = null;
214-
215212
var tenantId = TenantIdResolver.Resolve(_tenantId, requestContext, AdditionallyAllowedTenantIds);
216-
var isCachePopulated = Record switch
217-
{
218-
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
219-
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
220-
_ => false
221-
};
222213

223-
if (isCachePopulated)
214+
if (Record is not null)
224215
{
225216
try
226217
{
@@ -255,15 +246,6 @@ private async Task<AccessToken> GetTokenViaDeviceCodeAsync(TokenRequestContext c
255246
.ConfigureAwait(false);
256247

257248
Record = new AuthenticationRecord(result, ClientId);
258-
if (context.IsCaeEnabled)
259-
{
260-
_isCaeEnabledRequestCached = true;
261-
}
262-
else
263-
{
264-
_isCaeDisabledRequestCached = true;
265-
}
266-
267249
return new AccessToken(result.AccessToken, result.ExpiresOn);
268250
}
269251

sdk/identity/Azure.Identity/src/Credentials/InteractiveBrowserCredential.cs

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ public class InteractiveBrowserCredential : TokenCredential
2626
internal CredentialPipeline Pipeline { get; }
2727
internal bool DisableAutomaticAuthentication { get; }
2828
internal AuthenticationRecord Record { get; private set; }
29-
internal bool _isCaeEnabledRequestCached = false;
30-
internal bool _isCaeDisabledRequestCached = false;
3129
internal string DefaultScope { get; }
3230

3331
private const string AuthenticationRequiredMessage = "Interactive authentication is needed to acquire token. Call Authenticate to interactively authenticate.";
@@ -197,13 +195,7 @@ private async ValueTask<AccessToken> GetTokenImplAsync(bool async, TokenRequestC
197195
Exception inner = null;
198196

199197
var tenantId = TenantIdResolver.Resolve(TenantId ?? Record?.TenantId, requestContext, AdditionallyAllowedTenantIds);
200-
var isCachePopulated = Record switch
201-
{
202-
not null when requestContext.IsCaeEnabled && _isCaeEnabledRequestCached => true,
203-
not null when !requestContext.IsCaeEnabled && _isCaeDisabledRequestCached => true,
204-
_ => false
205-
};
206-
if (isCachePopulated)
198+
if (Record is not null)
207199
{
208200
try
209201
{
@@ -246,14 +238,6 @@ private async Task<AccessToken> GetTokenViaBrowserLoginAsync(TokenRequestContext
246238
.ConfigureAwait(false);
247239

248240
Record = new AuthenticationRecord(result, ClientId);
249-
if (context.IsCaeEnabled)
250-
{
251-
_isCaeEnabledRequestCached = true;
252-
}
253-
else
254-
{
255-
_isCaeDisabledRequestCached = true;
256-
}
257241
return new AccessToken(result.AccessToken, result.ExpiresOn);
258242
}
259243
}

sdk/identity/Azure.Identity/tests/InteractiveBrowserCredentialTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public override TokenCredential GetTokenCredential(CommonCredentialTestConfig co
4040
IsUnsafeSupportLoggingEnabled = config.IsUnsafeSupportLoggingEnabled,
4141
};
4242
var pipeline = CredentialPipeline.GetInstance(options);
43-
return InstrumentClient(new InteractiveBrowserCredential(config.TenantId, ClientId, options, pipeline, null) { _isCaeDisabledRequestCached = true, _isCaeEnabledRequestCached = true });
43+
return InstrumentClient(new InteractiveBrowserCredential(config.TenantId, ClientId, options, pipeline, null));
4444
}
4545

4646
[Test]

0 commit comments

Comments
 (0)