Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
*/
package io.trino.gateway.ha.config;

import io.airlift.units.Duration;

import java.util.concurrent.TimeUnit;

public class FormAuthConfiguration
{
private SelfSignKeyPairConfiguration selfSignKeyPair;
private String ldapConfigPath;
private Duration sessionTimeout = new Duration(30, TimeUnit.MINUTES);

public FormAuthConfiguration(SelfSignKeyPairConfiguration selfSignKeyPair, String ldapConfigPath)
{
Expand Down Expand Up @@ -45,4 +50,14 @@ public void setLdapConfigPath(String ldapConfigPath)
{
this.ldapConfigPath = ldapConfigPath;
}

public Duration getSessionTimeout()
{
return this.sessionTimeout;
}

public void setSessionTimeout(Duration sessionTimeout)
{
this.sessionTimeout = sessionTimeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,18 @@ else if (oauthManager != null) {
}
return Response.ok(Result.ok("Ok", loginType)).build();
}

@POST
@Path("serverInfo")
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public Response serverInfo()
{
long serverStartTime = System.currentTimeMillis();
if (formAuthManager != null) {
serverStartTime = formAuthManager.getServerStartTime();
}
Map<String, Object> serverInfo = Map.of("serverStart", serverStartTime);
return Response.ok(Result.ok(serverInfo)).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.gateway.ha.config.FormAuthConfiguration;
import io.trino.gateway.ha.config.LdapConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
import io.trino.gateway.ha.domain.Result;
import io.trino.gateway.ha.domain.request.RestLoginRequest;
import io.trino.gateway.ha.security.util.BasicCredentials;

import java.time.Instant;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
Expand All @@ -38,13 +42,16 @@
public class LbFormAuthManager
{
private static final Logger log = Logger.get(LbFormAuthManager.class);
private static final long SERVER_START_TIME = System.currentTimeMillis();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this is intended to be shared by multiple classes, we can move it out from this class and then create a separate config class. OR we can create a getter for this field

private static final Duration DEFAULT_SESSION_TIMEOUT = new Duration(30, TimeUnit.MINUTES);
/**
* Cookie key to pass the token.
*/
private final LbKeyProvider lbKeyProvider;
private final Map<String, UserConfiguration> presetUsers;
private final Map<String, String> pagePermissions;
private final LbLdapClient lbLdapClient;
private final Duration sessionTimeout;

public LbFormAuthManager(FormAuthConfiguration configuration,
Map<String, UserConfiguration> presetUsers,
Expand All @@ -58,9 +65,12 @@ public LbFormAuthManager(FormAuthConfiguration configuration,
if (configuration != null) {
this.lbKeyProvider = new LbKeyProvider(configuration
.getSelfSignKeyPair());
this.sessionTimeout = configuration.getSessionTimeout() != null ?
configuration.getSessionTimeout() : DEFAULT_SESSION_TIMEOUT;
}
else {
this.lbKeyProvider = null;
this.sessionTimeout = DEFAULT_SESSION_TIMEOUT;
}

if (configuration != null && configuration.getLdapConfigPath() != null) {
Expand Down Expand Up @@ -105,6 +115,21 @@ public Optional<Map<String, Claim>> getClaimsFromIdToken(String idToken)
DecodedJWT jwt = JWT.decode(idToken);

if (LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.empty())) {
// Check if token was issued before server restart
Optional<Claim> serverStartClaim = Optional.ofNullable(jwt.getClaim("server_start"));
if (serverStartClaim.isPresent() && !serverStartClaim.orElseThrow().isNull()) {
long tokenServerStart = serverStartClaim.orElseThrow().asLong();
if (tokenServerStart != SERVER_START_TIME) {
log.info("Token invalidated due to server restart");
return Optional.empty();
}
}
// Check token expiration
Optional<Date> expiresAt = Optional.ofNullable(jwt.getExpiresAt());
if (expiresAt.isPresent() && expiresAt.orElseThrow().before(new Date())) {
log.info("Token expired");
return Optional.empty();
}
return Optional.of(jwt.getClaims());
}
}
Expand All @@ -124,10 +149,16 @@ private String getSelfSignedToken(String username)

Map<String, Object> headers = Map.of("alg", "RS256");

Instant now = Instant.now();
Instant expiration = now.plusSeconds(sessionTimeout.roundTo(TimeUnit.SECONDS));

token = JWT.create()
.withHeader(headers)
.withIssuer(SessionCookie.SELF_ISSUER_ID)
.withSubject(username)
.withIssuedAt(Date.from(now))
.withExpiresAt(Date.from(expiration))
.withClaim("server_start", SERVER_START_TIME)
.sign(algorithm);
}
catch (JWTCreationException exception) {
Expand Down Expand Up @@ -167,4 +198,9 @@ public List<String> processPagePermissions(List<String> roles)
.flatMap(role -> Stream.of(pagePermissions.get(role).split("_")))
.distinct().toList();
}

public long getServerStartTime()
{
return SERVER_START_TIME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static NewCookie getTokenCookie(String token)
.path("/")
.domain("")
.comment("")
.maxAge(60 * 60 * 24)
.maxAge(60 * 15) // 15 minutes session timeout
.secure(true)
.build();
}
Expand Down
13 changes: 13 additions & 0 deletions webapp/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { useEffect } from 'react';
import { getCSSVar } from './utils/utils';
import { IllustrationIdle, IllustrationIdleDark } from '@douyinfe/semi-illustrations';
import Cookies from 'js-cookie';
import { SessionManager } from './utils/session';

function App() {
return (
Expand All @@ -40,6 +41,18 @@ function Screen() {
access.updateToken(token);
Cookies.remove('token');
}
// Initialize session management
const sessionManager = SessionManager.getInstance();
sessionManager.setSessionExpiredCallback(() => {
access.logout();
});

// Check token validity on app start
access.checkTokenValidity().catch(console.error);

return () => {
sessionManager.clearTimeout();
};
}, [])
return (
<>
Expand Down
34 changes: 33 additions & 1 deletion webapp/src/api/base.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { useAccessStore } from "../store";
import Locale, { getServerLang } from "../locales";
import { Toast } from "@douyinfe/semi-ui";
import { SessionManager } from "../utils/session";

export class ClientApi {
async get(url: string, params: Record<string, any> = {}): Promise<any> {
// Check token validity before making request
await this.validateTokenBeforeRequest(url);
let queryString = "";
if (Object.keys(params).length > 0) {
queryString = "?" + new URLSearchParams(params).toString();
Expand Down Expand Up @@ -40,6 +43,8 @@ export class ClientApi {
}

async post(url: string, body: Record<string, any> = {}): Promise<any> {
// Check token validity before making request
await this.validateTokenBeforeRequest(url);
const res: Response = await fetch(
this.path(url),
{
Expand Down Expand Up @@ -76,6 +81,8 @@ export class ClientApi {
}

async postForm(url: string, formData: FormData = new FormData()): Promise<any> {
// Check token validity before making request
await this.validateTokenBeforeRequest(url);
const res: Response = await fetch(
this.path(url),
{
Expand Down Expand Up @@ -104,6 +111,26 @@ export class ClientApi {
return resJson.data;
}

private async validateTokenBeforeRequest(url: string): Promise<void> {
// Skip validation for login-related endpoints to avoid infinite loops
if (url.includes('/login') || url.includes('/serverInfo') || url.includes('/loginType')) {
return;
}

const accessStore = useAccessStore.getState();
if (accessStore.token) {
try {
const isValid = await accessStore.checkTokenValidity();
if (!isValid) {
throw new Error('Token validation failed');
}
} catch (error) {
// Token validation failed, user will be logged out
throw error;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we doing anything in catch? Otherwise it may be unnecessary to catch the error

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andythsu yes, we are throwing error as token validation failed and user will be logged out in that case

}
}
}

path(path: string): string {
const proxyPath = import.meta.env.VITE_PROXY_PATH;
return [proxyPath, path].join("");
Expand Down Expand Up @@ -134,7 +161,12 @@ export function getHeaders(): Record<string, string> {
const validString = (x: string) => x && x.length > 0;

if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
// For synchronous header generation, we'll do basic token validation
// The async server restart check will happen in the session manager
const sessionManager = SessionManager.getInstance();
if (!sessionManager.isTokenExpired(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
}
}

return headers;
Expand Down
4 changes: 4 additions & 0 deletions webapp/src/api/webapp/login.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ export async function loginTypeApi(): Promise<any> {
export async function getUIConfiguration(): Promise<any> {
return api.get('/webapp/getUIConfiguration')
}

export async function serverInfoApi(): Promise<any> {
return api.post('/serverInfo', {})
}
50 changes: 49 additions & 1 deletion webapp/src/store/access.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { create } from "zustand";
import { persist } from "zustand/middleware";
import { StoreKey } from "../constant";
import { getInfoApi } from "../api/webapp/login";
import { getInfoApi, serverInfoApi } from "../api/webapp/login";
import { SessionManager } from "../utils/session";

export enum Role {
ADMIN = "ADMIN",
Expand All @@ -28,6 +29,8 @@ export interface AccessControlStore {
getUserInfo: (_?: boolean) => void;
hasRole: (role: Role) => boolean;
hasPermission: (permission: string | undefined) => boolean;
logout: () => void;
checkTokenValidity: () => Promise<boolean>;
}

let fetchState: number = 0; // 0 not fetch, 1 fetching, 2 done
Expand Down Expand Up @@ -78,6 +81,51 @@ export const useAccessStore = create<AccessControlStore>()(
const permissions = get().permissions
return permission == undefined || permissions == null || permissions.length == 0 || permissions.includes(permission);
},
logout() {
const sessionManager = SessionManager.getInstance();
sessionManager.clearTimeout();
set(() => ({
token: "",
userId: "",
userName: "",
nickName: "",
userType: "",
email: "",
phonenumber: "",
sex: "",
avatar: "",
permissions: [],
roles: [],
}));
fetchState = 0;
},
async checkTokenValidity() {
const token = get().token;
if (!token) return false;

const sessionManager = SessionManager.getInstance();

// Check if token is expired
if (sessionManager.isTokenExpired(token)) {
get().logout();
return false;
}

// Check for server restart
try {
const serverInfo = await serverInfoApi();
if (sessionManager.checkServerRestart(token, serverInfo.serverStart)) {
console.log('Server restart detected, logging out');
get().logout();
return false;
}
} catch (error) {
console.error('Error checking server info:', error);
// Don't logout on API error, just continue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we log out here? technically we should never end up in this state, but if we do, it means the server is having issues.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andythsu No, we are not logging out here, we are logging the server error here

}

return true;
},
}),
{
name: StoreKey.Access,
Expand Down
Loading