Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions src/api/coderApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
type AxiosInstance,
type AxiosHeaders,
type AxiosResponseTransformer,
isAxiosError,
} from "axios";
import { Api } from "coder/site/src/api/api";
import {
Expand Down Expand Up @@ -30,6 +31,12 @@ import {
HttpClientLogLevel,
} from "../logging/types";
import { sizeOf } from "../logging/utils";
import {
parseOAuthError,
requiresReAuthentication,
isNetworkError,
} from "../oauth/errors";
import { type OAuthSessionManager } from "../oauth/sessionManager";
import { type UnidirectionalStream } from "../websocket/eventStreamConnection";
import {
OneWayWebSocket,
Expand Down Expand Up @@ -58,14 +65,15 @@ export class CoderApi extends Api {
baseUrl: string,
token: string | undefined,
output: Logger,
oauthSessionManager?: OAuthSessionManager,
): CoderApi {
const client = new CoderApi(output);
client.setHost(baseUrl);
if (token) {
client.setSessionToken(token);
}

setupInterceptors(client, baseUrl, output);
setupInterceptors(client, baseUrl, output, oauthSessionManager);
return client;
}

Expand Down Expand Up @@ -302,6 +310,7 @@ function setupInterceptors(
client: CoderApi,
baseUrl: string,
output: Logger,
oauthSessionManager?: OAuthSessionManager,
): void {
addLoggingInterceptors(client.getAxiosInstance(), output);

Expand Down Expand Up @@ -334,6 +343,11 @@ function setupInterceptors(
throw await CertificateError.maybeWrap(err, baseUrl, output);
},
);

// OAuth token refresh interceptors
if (oauthSessionManager) {
addOAuthInterceptors(client, output, oauthSessionManager);
}
}

function addLoggingInterceptors(client: AxiosInstance, logger: Logger) {
Expand Down Expand Up @@ -363,7 +377,7 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) {
},
(error: unknown) => {
logError(logger, error, getLogLevel());
return Promise.reject(error);
throw error;
},
);

Expand All @@ -374,7 +388,80 @@ function addLoggingInterceptors(client: AxiosInstance, logger: Logger) {
},
(error: unknown) => {
logError(logger, error, getLogLevel());
return Promise.reject(error);
throw error;
},
);
}

/**
* Add OAuth token refresh interceptors.
* Success interceptor: proactively refreshes token when approaching expiry.
* Error interceptor: reactively refreshes token on 401/403 responses.
*/
function addOAuthInterceptors(
client: CoderApi,
logger: Logger,
oauthSessionManager: OAuthSessionManager,
) {
client.getAxiosInstance().interceptors.response.use(
// Success response interceptor: proactive token refresh
(response) => {
if (oauthSessionManager.shouldRefreshToken()) {
logger.debug(
"Token approaching expiry, triggering proactive refresh in background",
);

// Fire-and-forget: don't await, don't block response
oauthSessionManager.refreshToken().catch((error) => {
logger.warn("Background token refresh failed:", error);
});
}

return response;
},
// Error response interceptor: reactive token refresh on 401/403
async (error: unknown) => {
if (!isAxiosError(error)) {
throw error;
}

const status = error.response?.status;
if (status !== 401 && status !== 403) {
throw error;
}

if (!oauthSessionManager.isLoggedInWithOAuth()) {
throw error;
}

logger.info(`Received ${status} response, attempting token refresh`);

try {
const newTokens = await oauthSessionManager.refreshToken();
client.setSessionToken(newTokens.access_token);

logger.info("Token refresh successful, updated session token");
} catch (refreshError) {
logger.error("Token refresh failed:", refreshError);

const oauthError = parseOAuthError(refreshError);
if (oauthError && requiresReAuthentication(oauthError)) {
logger.error(
`OAuth error requires re-authentication: ${oauthError.errorCode}`,
);

oauthSessionManager
.showReAuthenticationModal(oauthError)
.catch((err) => {
logger.error("Failed to show re-auth modal:", err);
});
} else if (isNetworkError(refreshError)) {
logger.warn(
"Token refresh failed due to network error, will retry later",
);
}
}
throw error;
},
);
}
Expand Down
138 changes: 113 additions & 25 deletions src/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ import { type SecretsManager } from "./core/secretsManager";
import { CertificateError } from "./error";
import { getGlobalFlags } from "./globalFlags";
import { type Logger } from "./logging/logger";
import { OAuthMetadataClient } from "./oauth/metadataClient";
import { type OAuthSessionManager } from "./oauth/sessionManager";
import { escapeCommandArg, toRemoteAuthority, toSafeHost } from "./util";
import {
AgentTreeItem,
type OpenableTreeItem,
WorkspaceTreeItem,
} from "./workspace/workspacesProvider";

type AuthMethod = "oauth" | "legacy";

export class Commands {
private readonly vscodeProposed: typeof vscode;
private readonly logger: Logger;
Expand All @@ -48,6 +52,7 @@ export class Commands {
public constructor(
serviceContainer: ServiceContainer,
private readonly restClient: Api,
private readonly oauthSessionManager: OAuthSessionManager,
) {
this.vscodeProposed = serviceContainer.getVsCodeProposed();
this.logger = serviceContainer.getLogger();
Expand Down Expand Up @@ -205,10 +210,10 @@ export class Commands {
// It is possible that we are trying to log into an old-style host, in which
// case we want to write with the provided blank label instead of generating
// a host label.
const label = args?.label === undefined ? toSafeHost(url) : args.label;

const label = args?.label ?? toSafeHost(url);
// Try to get a token from the user, if we need one, and their user.
const autoLogin = args?.autoLogin === true;

const res = await this.maybeAskToken(url, args?.token, autoLogin);
if (!res) {
return; // The user aborted, or unable to auth.
Expand All @@ -228,7 +233,7 @@ export class Commands {

// These contexts control various menu items and the sidebar.
this.contextManager.set("coder.authenticated", true);
if (res.user.roles.find((role) => role.name === "owner")) {
if (res.user.roles.some((role) => role.name === "owner")) {
this.contextManager.set("coder.isOwner", true);
}

Expand Down Expand Up @@ -291,6 +296,68 @@ export class Commands {
}
}

// Check if server supports OAuth
const supportsOAuth = await this.checkOAuthSupport(client);

let choice: AuthMethod | undefined = "legacy";
if (supportsOAuth) {
choice = await this.askAuthMethod();
}

if (choice === "oauth") {
return this.loginWithOAuth(client);
} else if (choice === "legacy") {
const initialToken =
token || (await this.secretsManager.getSessionToken());
return this.loginWithToken(client, initialToken);
}

return null; // User aborted.
}

private async checkOAuthSupport(client: CoderApi): Promise<boolean> {
const metadataClient = new OAuthMetadataClient(
client.getAxiosInstance(),
this.logger,
);
return metadataClient.checkOAuthSupport();
}

/**
* Ask user to choose between OAuth and legacy API token authentication.
*/
private async askAuthMethod(): Promise<AuthMethod | undefined> {
const choice = await vscode.window.showQuickPick(
[
{
label: "$(key) OAuth (Recommended)",
detail: "Secure authentication with automatic token refresh",
value: "oauth" as const,
},
{
label: "$(lock) API Token",
detail: "Use a manually created API key",
value: "legacy" as const,
},
],
{
title: "Choose Authentication Method",
placeHolder: "How would you like to authenticate?",
ignoreFocusOut: true,
},
);

return choice?.value;
}

private async loginWithToken(
client: CoderApi,
initialToken: string | undefined,
): Promise<{ user: User; token: string } | null> {
const url = client.getAxiosInstance().defaults.baseURL;
if (!url) {
throw new Error("No base URL set on REST client");
}
// This prompt is for convenience; do not error if they close it since
// they may already have a token or already have the page opened.
await vscode.env.openExternal(vscode.Uri.parse(`${url}/cli-auth`));
Expand All @@ -303,7 +370,7 @@ export class Commands {
title: "Coder API Key",
password: true,
placeHolder: "Paste your API key.",
value: token || (await this.secretsManager.getSessionToken()),
value: initialToken,
ignoreFocusOut: true,
validateInput: async (value) => {
if (!value) {
Expand Down Expand Up @@ -335,12 +402,40 @@ export class Commands {
},
});

if (validatedToken && user) {
return { token: validatedToken, user };
if (user === undefined || validatedToken === undefined) {
return null;
}

// User aborted.
return null;
return { user, token: validatedToken };
}

/**
* Authenticate using OAuth flow.
* Returns the access token and authenticated user, or null if failed/cancelled.
*/
private async loginWithOAuth(
client: CoderApi,
): Promise<{ user: User; token: string } | null> {
try {
this.logger.info("Starting OAuth authentication");

const tokenResponse = await this.oauthSessionManager.login(client);

// Validate token by fetching user
client.setSessionToken(tokenResponse.access_token);
const user = await client.getAuthenticatedUser();

return {
token: tokenResponse.access_token,
user,
};
} catch (error) {
this.logger.error("OAuth authentication failed:", error);
vscode.window.showErrorMessage(
`OAuth authentication failed: ${getErrorMessage(error, "Unknown error")}`,
);
return null;
}
}

/**
Expand Down Expand Up @@ -377,6 +472,7 @@ export class Commands {
// Sanity check; command should not be available if no url.
throw new Error("You are not logged in");
}

await this.forceLogout();
}

Expand All @@ -385,6 +481,12 @@ export class Commands {
return;
}
this.logger.info("Logging out");

// Fire and forget
this.oauthSessionManager.logout().catch((error) => {
this.logger.warn("OAuth logout failed, continuing with cleanup:", error);
});

// Clear from the REST client. An empty url will indicate to other parts of
// the code that we are logged out.
this.restClient.setHost("");
Expand Down Expand Up @@ -501,7 +603,7 @@ export class Commands {
true,
);
} else {
throw new Error("Unable to open unknown sidebar item");
throw new TypeError("Unable to open unknown sidebar item");
}
} else {
// If there is no tree item, then the user manually ran this command.
Expand Down Expand Up @@ -547,27 +649,14 @@ export class Commands {
configDir,
);
terminal.sendText(
`${escapeCommandArg(binary)}${` ${globalFlags.join(" ")}`} ssh ${app.workspace_name}`,
`${escapeCommandArg(binary)} ${globalFlags.join(" ")} ssh ${app.workspace_name}`,
);
await new Promise((resolve) => setTimeout(resolve, 5000));
terminal.sendText(app.command ?? "");
terminal.show(false);
},
);
}
// Check if app has a URL to open
if (app.url) {
return vscode.window.withProgress(
{
location: vscode.ProgressLocation.Notification,
title: `Opening ${app.name || "application"} in browser...`,
cancellable: false,
},
async () => {
await vscode.env.openExternal(vscode.Uri.parse(app.url!));
},
);
}

// If no URL or command, show information about the app status
vscode.window.showInformationMessage(`${app.name}`, {
Expand Down Expand Up @@ -646,7 +735,7 @@ export class Commands {
workspaceAgent,
);

const hostPath = localWorkspaceFolder ? localWorkspaceFolder : undefined;
const hostPath = localWorkspaceFolder || undefined;
const configFile =
hostPath && localConfigFile
? {
Expand Down Expand Up @@ -748,7 +837,6 @@ export class Commands {
if (ex instanceof CertificateError) {
ex.showNotification();
}
return;
});
});
quickPick.show();
Expand Down
Loading