From 932d14f55e026cfab8f9470354ee4258a331b9e2 Mon Sep 17 00:00:00 2001 From: abs51295 Date: Fri, 6 Jul 2018 14:58:38 +0530 Subject: [PATCH] Add service token authentication mechanism --- openshift/template.yaml | 5 +++ src/auth.py | 70 ++++++++++++++++++++++++++++++++++++++--- src/rest_api.py | 9 +++++- src/utils.py | 25 +++++++++++++++ tests/test_auth.py | 12 +++---- 5 files changed, 110 insertions(+), 11 deletions(-) diff --git a/openshift/template.yaml b/openshift/template.yaml index 63a349b..9a0c698 100644 --- a/openshift/template.yaml +++ b/openshift/template.yaml @@ -95,6 +95,11 @@ objects: configMapKeyRef: name: bayesian-config key: keycloak-url + - name: BAYESIAN_AUTH_PUBLIC_KEYS_URL + valueFrom: + configMapKeyRef: + name: bayesian-config + key: auth-url - name: BAYESIAN_JWT_AUDIENCE value: "fabric8-online-platform,openshiftio-public" image: "${DOCKER_REGISTRY}/${DOCKER_IMAGE}:${IMAGE_TAG}" diff --git a/src/auth.py b/src/auth.py index fe3dd02..ed2a86f 100644 --- a/src/auth.py +++ b/src/auth.py @@ -5,12 +5,11 @@ import jwt from os import getenv - from exceptions import HTTPError -from utils import fetch_public_key +from utils import fetch_public_key, fetch_service_public_keys -def decode_token(token): +def decode_user_token(token): """Decode the authorization token read from the request header.""" if token is None: return {} @@ -38,6 +37,39 @@ def decode_token(token): return decoded_token +def decode_service_token(token): # pragma: no cover + """Decode OSIO service token.""" + # TODO: Merge this function and user token function once audience is removed from user tokens. + if token is None: + return {} + + if token.startswith('Bearer '): + _, token = token.split(' ', 1) + + pub_keys = fetch_service_public_keys(current_app) + decoded_token = None + + # Since we have multiple public keys, we need to verify against every public key. + # Token can be decoded by any one of the available public keys. + for pub_key in pub_keys: + try: + pub_key = '-----BEGIN PUBLIC KEY-----\n{pkey}\n-----END PUBLIC KEY-----'\ + .format(pkey=pub_key) + decoded_token = jwt.decode(token, pub_key, algorithms=['RS256']) + except jwt.InvalidTokenError: + current_app.logger.error("Auth token couldn't be decoded for public key: {}" + .format(pub_key)) + decoded_token = None + + if decoded_token: + break + + if not decoded_token: + raise jwt.InvalidTokenError('Auth token cannot be verified.') + + return decoded_token + + def get_token_from_auth_header(): """Get the authorization token read from the request header.""" return request.headers.get('Authorization') @@ -62,7 +94,37 @@ def wrapper(*args, **kwargs): lgr = current_app.logger try: - decoded = decode_token(get_token_from_auth_header()) + decoded = decode_user_token(get_token_from_auth_header()) + if not decoded: + lgr.exception('Provide an Authorization token with the API request') + raise HTTPError(401, 'Authentication failed - token missing') + + lgr.info('Successfuly authenticated user {e} using JWT'. + format(e=decoded.get('email'))) + except jwt.ExpiredSignatureError as exc: + lgr.exception('Expired JWT token') + raise HTTPError(401, 'Authentication failed - token has expired') from exc + except Exception as exc: + lgr.exception('Failed decoding JWT token') + raise HTTPError(401, 'Authentication failed - could not decode JWT token') from exc + + return view(*args, **kwargs) + + return wrapper + + +def service_token_required(view): # pragma: no cover + """Check if the request contains a valid service token.""" + @wraps(view) + def wrapper(*args, **kwargs): + # Disable authentication for local setup + if getenv('DISABLE_AUTHENTICATION') in ('1', 'True', 'true'): + return view(*args, **kwargs) + + lgr = current_app.logger + + try: + decoded = decode_service_token(get_token_from_auth_header()) if not decoded: lgr.exception('Provide an Authorization token with the API request') raise HTTPError(401, 'Authentication failed - token missing') diff --git a/src/rest_api.py b/src/rest_api.py index 53f3a21..e07fbf9 100644 --- a/src/rest_api.py +++ b/src/rest_api.py @@ -4,7 +4,7 @@ from flask_cors import CORS from utils import DatabaseIngestion, scan_repo, validate_request_data, retrieve_worker_result from f8a_worker.setup_celery import init_selinon -from auth import login_required +from auth import login_required, service_token_required from exceptions import HTTPError app = Flask(__name__) @@ -215,5 +215,12 @@ def handle_error(e): # pragma: no cover }), e.status_code +@app.route('/test-service-token') +@service_token_required +def test_service_token(): # pragma: no cover + """Test the service token authentication mechanism.""" + return flask.jsonify({'token': 'is_valid'}), 200 + + if __name__ == "__main__": # pragma: no cover app.run() diff --git a/src/utils.py b/src/utils.py index fd8a9ba..05af467 100644 --- a/src/utils.py +++ b/src/utils.py @@ -292,3 +292,28 @@ def fetch_public_key(app): app.public_key = None return app.public_key + + +def fetch_service_public_keys(app): # pragma: no cover + """Get public keys for OSIO service account. Currently, there are three public keys.""" + if not getattr(app, "service_public_keys", []): + auth_url = os.getenv('BAYESIAN_AUTH_PUBLIC_KEYS_URL', '') + if auth_url: + try: + auth_url = auth_url.strip('/') + '/api/token/keys?format=pem' + result = requests.get(auth_url, timeout=0.5) + app.logger.info('Fetching public key from %s, status %d, result: %s', + auth_url, result.status_code, result.text) + except requests.exceptions.Timeout: + app.logger.error('Timeout fetching public key from %s', auth_url) + return '' + if result.status_code != 200: + return '' + + keys = result.json().get('keys', []) + app.service_public_keys = keys + + else: + app.service_public_keys = None + + return app.service_public_keys diff --git a/tests/test_auth.py b/tests/test_auth.py index daaeb55..bf8d9ac 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -68,7 +68,7 @@ def mocked_get_audiences_3(): @patch("auth.fetch_public_key", side_effect=mocked_fetch_public_key_1) def test_decode_token_invalid_input_1(mocked_fetch_public_key, mocked_get_audiences): """Test the invalid input handling during token decoding.""" - assert decode_token(None) == {} + assert decode_user_token(None) == {} @patch("auth.get_audiences", side_effect=mocked_get_audiences) @@ -76,7 +76,7 @@ def test_decode_token_invalid_input_1(mocked_fetch_public_key, mocked_get_audien def test_decode_token_invalid_input_2(mocked_fetch_public_key, mocked_get_audiences): """Test the invalid input handling during token decoding.""" with pytest.raises(Exception): - assert decode_token("Foobar") is None + assert decode_user_token("Foobar") is None @patch("auth.get_audiences", side_effect=mocked_get_audiences) @@ -84,7 +84,7 @@ def test_decode_token_invalid_input_2(mocked_fetch_public_key, mocked_get_audien def test_decode_token_invalid_input_3(mocked_fetch_public_key, mocked_get_audiences): """Test the invalid input handling during token decoding.""" with pytest.raises(Exception): - assert decode_token("Bearer ") is None + assert decode_user_token("Bearer ") is None @patch("auth.get_audiences", side_effect=mocked_get_audiences) @@ -92,7 +92,7 @@ def test_decode_token_invalid_input_3(mocked_fetch_public_key, mocked_get_audien def test_decode_token_invalid_input_4(mocked_fetch_public_key, mocked_get_audiences): """Test the invalid input handling during token decoding.""" with pytest.raises(Exception): - assert decode_token("Bearer ") is None + assert decode_user_token("Bearer ") is None @patch("auth.get_audiences", side_effect=mocked_get_audiences_2) @@ -100,7 +100,7 @@ def test_decode_token_invalid_input_4(mocked_fetch_public_key, mocked_get_audien def test_decode_token_invalid_input_5(mocked_fetch_public_key, mocked_get_audiences): """Test the handling wrong JWT tokens.""" with pytest.raises(Exception): - assert decode_token("Bearer something") is None + assert decode_user_token("Bearer something") is None @patch("auth.get_audiences", side_effect=mocked_get_audiences_3) @@ -112,7 +112,7 @@ def test_decode_token_invalid_input_6(mocked_fetch_public_key, mocked_get_audien 'aud': 'foo:bar' } token = jwt.encode(payload, PRIVATE_KEY, algorithm='RS256').decode("utf-8") - assert decode_token(token) is not None + assert decode_user_token(token) is not None def test_audiences():