Skip to content

Commit 7214900

Browse files
committed
fix for floconsole and kms
1 parent e6a251d commit 7214900

5 files changed

Lines changed: 226 additions & 10 deletions

File tree

wavefront/server/apps/floconsole/floconsole/services/token_service.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
self.is_dev = app_env == 'dev' or (kms_service is None)
4040
self.private_key = self._load_key(private_key)
4141
self.public_key = self._load_key(public_key)
42-
self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value
42+
self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev)
4343
self.token_expiry = int(token_expiry)
4444
self.temporary_token_expiry = int(temporary_token_expiry)
4545
self.kms_service = kms_service
@@ -51,6 +51,30 @@ def _load_key(self, key: str):
5151
key = base64.b64decode(key).decode('ascii')
5252
return key
5353

54+
@staticmethod
55+
def _resolve_algorithm(
56+
kms_service: FloKMS | None,
57+
configured: TokenAlgorithms,
58+
is_dev: bool,
59+
) -> str:
60+
if is_dev:
61+
return TokenAlgorithms.RS256.value
62+
if kms_service is not None:
63+
getter = getattr(kms_service, 'jwt_algorithm', None)
64+
if callable(getter):
65+
return getter()
66+
return configured.value
67+
68+
def _jwt_decode_algorithms(self) -> list[str]:
69+
"""Allow legacy PS256 headers on RS256 (PKCS1) KMS signatures."""
70+
algorithms = [self.algorithm]
71+
if (
72+
self.algorithm == TokenAlgorithms.RS256.value
73+
and TokenAlgorithms.PS256.value not in algorithms
74+
):
75+
algorithms.append(TokenAlgorithms.PS256.value)
76+
return algorithms
77+
5478
def create_token(
5579
self,
5680
sub: str | None = None,
@@ -137,14 +161,14 @@ def decode_token(self, token: str) -> dict:
137161

138162
is_valid = self.kms_service.verify(message=digest, signature=signature)
139163
if not is_valid:
140-
return {}
164+
raise ValueError('Invalid token signature')
141165

142166
public_key_pem = self.kms_service.get_public_key_pem()
143167

144168
decoded = jwt.decode(
145169
clean_token,
146170
public_key_pem,
147-
algorithms=[self.algorithm],
171+
algorithms=self._jwt_decode_algorithms(),
148172
issuer=self.issuer,
149173
audience=self.audience,
150174
)

wavefront/server/modules/auth_module/auth_module/services/token_service.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self.is_dev = app_env == 'dev' or (kms_service is None)
3939
self.private_key = self._load_key(private_key)
4040
self.public_key = self._load_key(public_key)
41-
self.algorithm = TokenAlgorithms.RS256.value if self.is_dev else algorithm.value
41+
self.algorithm = self._resolve_algorithm(kms_service, algorithm, self.is_dev)
4242
self.token_expiry = int(token_expiry)
4343
self.temporary_token_expiry = int(temporary_token_expiry)
4444
self.kms_service = kms_service
@@ -49,6 +49,29 @@ def _load_key(self, key: str):
4949
key = base64.b64decode(key).decode('ascii')
5050
return key
5151

52+
@staticmethod
53+
def _resolve_algorithm(
54+
kms_service: FloKMS,
55+
configured: TokenAlgorithms,
56+
is_dev: bool,
57+
) -> str:
58+
if is_dev:
59+
return TokenAlgorithms.RS256.value
60+
if kms_service is not None:
61+
getter = getattr(kms_service, 'jwt_algorithm', None)
62+
if callable(getter):
63+
return getter()
64+
return configured.value
65+
66+
def _jwt_decode_algorithms(self) -> list[str]:
67+
algorithms = [self.algorithm]
68+
if (
69+
self.algorithm == TokenAlgorithms.RS256.value
70+
and TokenAlgorithms.PS256.value not in algorithms
71+
):
72+
algorithms.append(TokenAlgorithms.PS256.value)
73+
return algorithms
74+
5275
def create_token(
5376
self,
5477
sub: str | None = None,
@@ -118,14 +141,14 @@ def decode_token(self, token: str) -> dict:
118141

119142
is_valid = self.kms_service.verify(message=digest, signature=signature)
120143
if not is_valid:
121-
return {}
144+
raise ValueError('Invalid token signature')
122145

123146
public_key_pem = self.kms_service.get_public_key_pem()
124147

125148
decoded = jwt.decode(
126149
token,
127150
public_key_pem,
128-
algorithms=[self.algorithm],
151+
algorithms=self._jwt_decode_algorithms(),
129152
issuer=self.issuer,
130153
audience=self.audience,
131154
)

wavefront/server/packages/flo_cloud/flo_cloud/gcp/kms.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
gcp_crypto_key = os.getenv('GCP_KMS_CRYPTO_KEY')
1717
gcp_crypto_key_version = os.getenv('GCP_KMS_CRYPTO_KEY_VERSION')
1818

19+
# GCP KMS PKCS#1 v1.5 signing algorithms (JWT alg RS256)
20+
_PKCS1_ALGORITHMS = frozenset(
21+
{
22+
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_2048_SHA256,
23+
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_3072_SHA256,
24+
kms_v1.CryptoKeyVersion.CryptoKeyVersionAlgorithm.RSA_SIGN_PKCS1_4096_SHA256,
25+
}
26+
)
27+
1928

2029
class GcpKMS(FloKMS):
2130
def __init__(self):
@@ -40,6 +49,11 @@ def __init__(self):
4049
crypto_key=gcp_crypto_key,
4150
crypto_key_version=gcp_crypto_key_version,
4251
)
52+
public_key = self.kms_client.get_public_key(
53+
request=kms_v1.GetPublicKeyRequest(name=self.key_name)
54+
)
55+
self._key_algorithm = public_key.algorithm
56+
self._uses_pkcs1 = self._key_algorithm in _PKCS1_ALGORITHMS
4357

4458
def encrypt(self, plaintext: str) -> bytes:
4559
request = kms_v1.EncryptRequest(
@@ -68,20 +82,29 @@ def sign(self, message: bytes, **kwargs) -> bytes:
6882
response = self.kms_client.asymmetric_sign(request=request)
6983
return response.signature
7084

85+
def jwt_algorithm(self) -> str:
86+
"""JWT alg header matching this KMS key (RS256 for PKCS1 keys, PS256 for PSS)."""
87+
return 'RS256' if self._uses_pkcs1 else 'PS256'
88+
7189
def verify(self, message: bytes, signature: bytes, **kwargs) -> bool:
7290
public_key_pem: bytes | str = self.get_public_key_pem(encode=True)
7391
if isinstance(public_key_pem, str):
7492
raise ValueError('Public key is not a bytes object')
7593
rsa_key = serialization.load_pem_public_key(public_key_pem, default_backend())
7694

95+
if self._uses_pkcs1:
96+
verify_padding = padding.PKCS1v15()
97+
else:
98+
verify_padding = padding.PSS(
99+
mgf=padding.MGF1(hashes.SHA256()),
100+
salt_length=padding.PSS.MAX_LENGTH,
101+
)
102+
77103
try:
78104
rsa_key.verify( # type: ignore
79105
signature=signature,
80106
data=message,
81-
padding=padding.PSS( # type: ignore
82-
mgf=padding.MGF1(hashes.SHA256()),
83-
salt_length=padding.PSS.MAX_LENGTH,
84-
),
107+
padding=verify_padding,
85108
algorithm=utils.Prehashed(hashes.SHA256()), # type: ignore
86109
)
87110
return True

wavefront/server/packages/flo_cloud/flo_cloud/kms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ def verify(self, message: bytes, signature: bytes, **kwargs) -> bool:
3737

3838
def get_public_key_pem(self, **kwargs) -> bytes | str:
3939
return self.kms_client.get_public_key_pem(**kwargs)
40+
41+
def jwt_algorithm(self) -> str:
42+
getter = getattr(self.kms_client, 'jwt_algorithm', None)
43+
if callable(getter):
44+
return getter()
45+
return 'PS256'
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Simulate /floconsole/v1/authenticate token create + require_auth decode (KMS).
4+
5+
Usage:
6+
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
7+
export GCP_PROJECT_ID=... GCP_LOCATION=... GCP_KMS_KEY_RING=...
8+
export GCP_KMS_CRYPTO_KEY=... GCP_KMS_CRYPTO_KEY_VERSION=...
9+
uv run python scripts/test_kms_auth_flow.py
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import base64
15+
import os
16+
import subprocess
17+
import sys
18+
import tempfile
19+
from uuid import uuid4
20+
21+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../packages/flo_cloud'))
22+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../apps/floconsole'))
23+
24+
from flo_cloud.gcp.kms import GcpKMS
25+
from flo_cloud.kms import FloKmsService
26+
from floconsole.constants.auth import AUTH_ROLE_ID
27+
from floconsole.services.token_service import TokenAlgorithms, TokenService
28+
29+
ISSUER = os.getenv('CONSOLE_JWT_ISSUER', 'https://floconsole.rootflo.ai')
30+
AUDIENCE = os.getenv('CONSOLE_JWT_AUDIENCE', 'https://floconsole.rootflo.ai')
31+
PREFIX = os.getenv('CONSOLE_TOKEN_PREFIX', 'fc_')
32+
33+
34+
def _dummy_pem_keys() -> tuple[str, str]:
35+
with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as priv:
36+
subprocess.run(
37+
['openssl', 'genrsa', '-out', priv.name, '2048'],
38+
check=True,
39+
capture_output=True,
40+
)
41+
priv_pem = open(priv.name, 'rb').read()
42+
pub_proc = subprocess.run(
43+
['openssl', 'rsa', '-pubout'],
44+
input=priv_pem,
45+
capture_output=True,
46+
check=True,
47+
)
48+
return base64.b64encode(priv_pem).decode(), base64.b64encode(
49+
pub_proc.stdout
50+
).decode()
51+
52+
53+
def _simulate_require_auth(decoded: dict) -> str | None:
54+
"""Mirror floconsole require_auth checks after decode_token."""
55+
if 'session_id' not in decoded:
56+
return 'Invalid token: missing session_id'
57+
if 'role_id' not in decoded or decoded['role_id'] != AUTH_ROLE_ID:
58+
return 'Invalid token: Not the console user'
59+
return None
60+
61+
62+
def main() -> int:
63+
print('=== KMS auth flow test (create_token + decode_token) ===\n')
64+
65+
for var in (
66+
'GCP_PROJECT_ID',
67+
'GCP_LOCATION',
68+
'GCP_KMS_KEY_RING',
69+
'GCP_KMS_CRYPTO_KEY',
70+
'GCP_KMS_CRYPTO_KEY_VERSION',
71+
'GOOGLE_APPLICATION_CREDENTIALS',
72+
):
73+
print(f' {var}={os.environ.get(var, "<not set>")}')
74+
75+
print('\n--- Step 1: Init KMS (same as ApplicationContainer) ---')
76+
kms = FloKmsService(cloud_provider='gcp')
77+
gcp: GcpKMS = kms.kms_client # type: ignore[assignment]
78+
print(f' KMS key: {gcp.key_name}')
79+
print(f' jwt_algorithm(): {kms.jwt_algorithm()}')
80+
print(f' uses_pkcs1: {gcp._uses_pkcs1}')
81+
82+
priv, pub = _dummy_pem_keys()
83+
token_service = TokenService(
84+
private_key=priv,
85+
public_key=pub,
86+
kms_service=kms,
87+
algorithm=TokenAlgorithms.PS256,
88+
app_env='production',
89+
token_prefix=PREFIX,
90+
issuer=ISSUER,
91+
audience=AUDIENCE,
92+
)
93+
print('\n--- Step 2: TokenService (production / KMS) ---')
94+
print(f' is_dev={token_service.is_dev}')
95+
print(f' algorithm={token_service.algorithm}')
96+
97+
session_id = str(uuid4())
98+
user_id = str(uuid4())
99+
print('\n--- Step 3: create_token (POST /authenticate) ---')
100+
token = token_service.create_token(
101+
sub='admin@rootflo.ai',
102+
user_id=user_id,
103+
role_id=AUTH_ROLE_ID,
104+
payload={'session_id': session_id},
105+
)
106+
print(f' token length={len(token)}')
107+
print(f' prefix ok={token.startswith(PREFIX)}')
108+
header_alg = __import__('json').loads(
109+
base64.urlsafe_b64decode(token[len(PREFIX) :].split('.')[0] + '==')
110+
)['alg']
111+
print(f' JWT header alg={header_alg}')
112+
113+
print('\n--- Step 4: decode_token (require_auth middleware) ---')
114+
try:
115+
decoded = token_service.decode_token(token)
116+
except ValueError as e:
117+
print(f' FAIL ValueError: {e}')
118+
return 1
119+
except Exception as e:
120+
print(f' FAIL {type(e).__name__}: {e}')
121+
return 1
122+
123+
print(f' decoded session_id={decoded.get("session_id")}')
124+
print(f' decoded role_id={decoded.get("role_id")}')
125+
print(f' decoded iss={decoded.get("iss")}')
126+
127+
err = _simulate_require_auth(decoded)
128+
if err:
129+
print('\n--- Step 5: require_auth ---')
130+
print(f' FAIL: {err}')
131+
return 1
132+
133+
print('\n--- Step 5: require_auth ---')
134+
print(' OK: token would be accepted')
135+
print('\n=== PASS: full KMS create + validate flow ===')
136+
return 0
137+
138+
139+
if __name__ == '__main__':
140+
sys.exit(main())

0 commit comments

Comments
 (0)