Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions dev-auth/oidc-provider.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ const configuration = {
"http://localhost:3003/api/auth/oauth2/callback/oidc",
"http://localhost:3000/api/auth/oauth2/callback/okta",
],
// Post-logout redirect URIs for RP-Initiated Logout
post_logout_redirect_uris: [
"http://localhost:3000/signin",
"http://localhost:3001/signin",
"http://localhost:3002/signin",
"http://localhost:3003/signin",
],
response_types: ["code"],
grant_types: ["authorization_code", "refresh_token"],
token_endpoint_auth_method: "client_secret_post",
Expand Down Expand Up @@ -77,6 +84,19 @@ const configuration = {
// Disable built-in dev interactions in favor of our custom auto-login
// and auto-consent middleware used for local testing.
devInteractions: { enabled: false },
// Enable RP-Initiated Logout for sign-out flow
rpInitiatedLogout: {
enabled: true,
// Auto-confirm logout for dev (skip confirmation page)
logoutSource: async (ctx, form) => {
ctx.body = `<!DOCTYPE html>
<html><head><title>Logging out...</title></head>
<body>
${form}
<script>document.forms[0].submit();</script>
</body></html>`;
},
},
},
// Explicitly declare supported scopes, including offline_access for refresh tokens
// Scopes supported by the dev provider (app requests offline_access in dev and prod)
Expand Down
43 changes: 31 additions & 12 deletions src/app/api/auth/refresh-token/route.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { cookies } from "next/headers";
import { cookies, headers } from "next/headers";
import type { NextRequest } from "next/server";
import { NextResponse } from "next/server";
import { refreshAccessToken } from "@/lib/auth/auth";
import { BETTER_AUTH_SECRET, COOKIE_NAME } from "@/lib/auth/constants";
import { auth, refreshAccessToken } from "@/lib/auth/auth";
import {
BETTER_AUTH_SECRET,
OIDC_TOKEN_COOKIE_NAME,
} from "@/lib/auth/constants";
import type { OidcTokenData } from "@/lib/auth/types";
import { decrypt } from "@/lib/auth/utils";

Expand All @@ -13,15 +16,30 @@ import { decrypt } from "@/lib/auth/utils";
*/
export async function POST(request: NextRequest) {
try {
// Check if Better Auth session exists before attempting token refresh
const session = await auth.api.getSession({
headers: await headers(),
});

if (!session?.user?.id) {
// No active session - skip token refresh (user is logged out)
return NextResponse.json({ error: "No active session" }, { status: 401 });
}

const body = await request.json();
const { userId } = body;

if (!userId) {
return NextResponse.json({ error: "Missing userId" }, { status: 400 });
}

// Verify userId matches the session
if (userId !== session.user.id) {
return NextResponse.json({ error: "User ID mismatch" }, { status: 401 });
}

const cookieStore = await cookies();
const encryptedCookie = cookieStore.get(COOKIE_NAME);
const encryptedCookie = cookieStore.get(OIDC_TOKEN_COOKIE_NAME);

if (!encryptedCookie?.value) {
return NextResponse.json({ error: "No token found" }, { status: 401 });
Expand All @@ -32,32 +50,33 @@ export async function POST(request: NextRequest) {
tokenData = await decrypt(encryptedCookie.value, BETTER_AUTH_SECRET);
} catch (error) {
console.error("[Refresh API] Token decryption failed:", error);
cookieStore.delete(COOKIE_NAME);
cookieStore.delete(OIDC_TOKEN_COOKIE_NAME);
return NextResponse.json({ error: "Invalid token" }, { status: 401 });
}

if (tokenData.userId !== userId) {
console.error("[Refresh API] Token userId mismatch");
cookieStore.delete(COOKIE_NAME);
cookieStore.delete(OIDC_TOKEN_COOKIE_NAME);
return NextResponse.json({ error: "Invalid token" }, { status: 401 });
}

if (!tokenData.refreshToken) {
console.error("[Refresh API] No refresh token available");
cookieStore.delete(COOKIE_NAME);
cookieStore.delete(OIDC_TOKEN_COOKIE_NAME);
return NextResponse.json({ error: "No refresh token" }, { status: 401 });
}

// Call refreshAccessToken which will save the new token in the cookie
const refreshedData = await refreshAccessToken(
tokenData.refreshToken,
const refreshedData = await refreshAccessToken({
refreshToken: tokenData.refreshToken,
refreshTokenExpiresAt: tokenData.refreshTokenExpiresAt,
userId,
tokenData.refreshTokenExpiresAt,
);
idToken: tokenData.idToken,
});

if (!refreshedData) {
console.error("[Refresh API] Token refresh failed");
cookieStore.delete(COOKIE_NAME);
cookieStore.delete(OIDC_TOKEN_COOKIE_NAME);
return NextResponse.json(
{ error: "[Refresh API] Refresh failed" },
{ status: 401 },
Expand Down
11 changes: 0 additions & 11 deletions src/app/catalog/page.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
import { headers } from "next/headers";
import { redirect } from "next/navigation";
import { auth } from "@/lib/auth/auth";
import { getServers } from "./actions";
import { ServersWrapper } from "./components/servers-wrapper";

export default async function CatalogPage() {
const session = await auth.api.getSession({
headers: await headers(),
});

if (!session) {
redirect("/signin");
}

const servers = await getServers();

return <ServersWrapper servers={servers} />;
Expand Down
2 changes: 1 addition & 1 deletion src/app/signin/signin-button.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export function SignInButton({ providerId }: { providerId: string }) {
return (
<Button
onClick={handleOIDCSignIn}
className="w-full h-9 gap-2"
className="w-full h-9 gap-2 cursor-pointer"
size="default"
disabled={isLoading}
>
Expand Down
133 changes: 133 additions & 0 deletions src/lib/auth/__tests__/auth-client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { clearRecordedRequests, server } from "@/mocks/node";
import * as actions from "../actions";

// Remove global mock of auth-client from vitest.setup.ts
vi.unmock("@/lib/auth/auth-client");

// Hoist mocks
const mockAuthClientSignOut = vi.hoisted(() => vi.fn());
const mockLocationReplace = vi.hoisted(() => vi.fn());

// Mock Better Auth client
vi.mock("better-auth/client/plugins", () => ({
genericOAuthClient: vi.fn(() => ({})),
}));

vi.mock("better-auth/react", () => ({
createAuthClient: vi.fn(() => ({
signIn: vi.fn(),
useSession: vi.fn(),
signOut: mockAuthClientSignOut,
})),
}));

// Mock window.location globally
Object.defineProperty(globalThis, "window", {
value: {
location: {
replace: mockLocationReplace,
},
},
writable: true,
configurable: true,
});

describe("signOut", () => {
beforeEach(() => {
vi.clearAllMocks();
clearRecordedRequests();
});

afterEach(() => {
vi.restoreAllMocks();
server.resetHandlers();
});

it("calls getOidcSignOutUrl, clearOidcTokenAction, redirect, and authClient.signOut", async () => {
const oidcLogoutUrl = "https://okta.example.com/logout?id_token_hint=xxx";

// Spy on server actions
const getOidcSignOutUrlSpy = vi
.spyOn(actions, "getOidcSignOutUrl")
.mockResolvedValue(oidcLogoutUrl);
const clearOidcTokenActionSpy = vi
.spyOn(actions, "clearOidcTokenAction")
.mockResolvedValue(undefined);
mockAuthClientSignOut.mockResolvedValue(undefined);

const { signOut } = await import("../auth-client");

await signOut();

// Verify all functions were called
expect(getOidcSignOutUrlSpy).toHaveBeenCalledTimes(1);
expect(clearOidcTokenActionSpy).toHaveBeenCalledTimes(1);
expect(mockLocationReplace).toHaveBeenCalledWith(oidcLogoutUrl);
expect(mockAuthClientSignOut).toHaveBeenCalledTimes(1);
});

it("calls functions in correct order", async () => {
const callOrder: string[] = [];

vi.spyOn(actions, "getOidcSignOutUrl").mockImplementation(async () => {
callOrder.push("getOidcSignOutUrl");
return "https://okta.example.com/logout";
});

vi.spyOn(actions, "clearOidcTokenAction").mockImplementation(async () => {
callOrder.push("clearOidcTokenAction");
});

mockLocationReplace.mockImplementation(() => {
callOrder.push("window.location.replace");
});

mockAuthClientSignOut.mockImplementation(async () => {
callOrder.push("authClient.signOut");
});

const { signOut } = await import("../auth-client");

await signOut();

expect(callOrder).toEqual([
"getOidcSignOutUrl",
"clearOidcTokenAction",
"window.location.replace",
"authClient.signOut",
]);
});

it("redirects to /signin on error", async () => {
const consoleErrorSpy = vi
.spyOn(console, "error")
.mockImplementation(() => {});

vi.spyOn(actions, "getOidcSignOutUrl").mockRejectedValue(
new Error("Network error"),
);

const { signOut } = await import("../auth-client");

await signOut();

expect(mockLocationReplace).toHaveBeenCalledWith("/signin");
expect(consoleErrorSpy).toHaveBeenCalledWith(
"[Auth] Sign out error:",
expect.any(Error),
);
});

it("uses /signin as fallback when no OIDC URL", async () => {
vi.spyOn(actions, "getOidcSignOutUrl").mockResolvedValue("/signin");
vi.spyOn(actions, "clearOidcTokenAction").mockResolvedValue(undefined);
mockAuthClientSignOut.mockResolvedValue(undefined);

const { signOut } = await import("../auth-client");

await signOut();

expect(mockLocationReplace).toHaveBeenCalledWith("/signin");
});
});
41 changes: 33 additions & 8 deletions src/lib/auth/__tests__/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,14 @@ describe("auth", () => {

it("should return null when token is expired", async () => {
const expiredTokenData: OidcTokenData = {
id: "expired-token-id",
createdAt: new Date(),
updatedAt: new Date(),
providerId: "provider-id",
accountId: "account-id",
accessToken: "expired-token",
userId: "user-123",
expiresAt: Date.now() - 1000, // Expired 1 second ago
accessTokenExpiresAt: Date.now() - 1000, // Expired 1 second ago
};

const encryptedPayload = await encrypt(
Expand All @@ -102,9 +107,14 @@ describe("auth", () => {

it("should return null when token belongs to different user", async () => {
const tokenData: OidcTokenData = {
id: "valid-token-id",
createdAt: new Date(),
updatedAt: new Date(),
providerId: "provider-id",
accountId: "account-id",
accessToken: "valid-token",
userId: "user-456", // Different user
expiresAt: Date.now() + 3600000,
accessTokenExpiresAt: Date.now() + 3600000,
};

const encryptedPayload = await encrypt(
Expand All @@ -120,9 +130,14 @@ describe("auth", () => {

it("should return access token when valid", async () => {
const tokenData: OidcTokenData = {
id: "valid-token-id",
createdAt: new Date(),
updatedAt: new Date(),
providerId: "provider-id",
accountId: "account-id",
accessToken: "valid-access-token-123",
userId: "user-123",
expiresAt: Date.now() + 3600000, // Valid for 1 hour
accessTokenExpiresAt: Date.now() + 3600000, // Valid for 1 hour
};

const encryptedPayload = await encrypt(
Expand All @@ -138,7 +153,7 @@ describe("auth", () => {

it("should return null when token data is invalid", async () => {
// Create invalid token data (missing required fields)
const invalidData = { accessToken: "token" }; // Missing userId and expiresAt
const invalidData = { accessToken: "token" }; // Missing userId and accessTokenExpiresAt
const invalidPayload = await encrypt(
invalidData as OidcTokenData,
process.env.BETTER_AUTH_SECRET as string,
Expand Down Expand Up @@ -181,26 +196,36 @@ describe("auth", () => {
describe("OidcTokenData Type Guard", () => {
it("should validate correct OidcTokenData structure", () => {
const validData: OidcTokenData = {
id: "valid-token-id",
createdAt: new Date(),
updatedAt: new Date(),
providerId: "provider-id",
accountId: "account-id",
accessToken: "token",
userId: "user-123",
expiresAt: Date.now() + 3600000,
accessTokenExpiresAt: Date.now() + 3600000,
refreshToken: "refresh-token",
};

// Type guard is private, so we test indirectly through getOidcProviderAccessToken
expect(validData).toHaveProperty("accessToken");
expect(validData).toHaveProperty("userId");
expect(validData).toHaveProperty("expiresAt");
expect(validData).toHaveProperty("accessTokenExpiresAt");
expect(typeof validData.accessToken).toBe("string");
expect(typeof validData.userId).toBe("string");
expect(typeof validData.expiresAt).toBe("number");
expect(typeof validData.accessTokenExpiresAt).toBe("number");
});

it("should handle optional refreshToken", () => {
const dataWithoutRefresh: OidcTokenData = {
id: "valid-token-id",
createdAt: new Date(),
updatedAt: new Date(),
providerId: "provider-id",
accountId: "account-id",
accessToken: "token",
userId: "user-123",
expiresAt: Date.now() + 3600000,
accessTokenExpiresAt: Date.now() + 3600000,
};

expect(dataWithoutRefresh.refreshToken).toBeUndefined();
Expand Down
Loading
Loading