Skip to content

Commit

Permalink
[DI-466] Add error handling if database credentials env is missing (#38)
Browse files Browse the repository at this point in the history
* Add error handling for db connection
  • Loading branch information
alicia-koh authored Apr 11, 2024
1 parent e136ba9 commit 611ae03
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 45 deletions.
97 changes: 55 additions & 42 deletions services/alp-dataflow-gen/pysrc/alpconnection/dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,49 @@
from sqlalchemy import create_engine, text
import datetime
from uuid import UUID
from utils.types import DBCredentials, HANA_TENANT_USERS, PG_TENANT_USERS
from utils.types import DBCredentialsType, HANA_TENANT_USERS, PG_TENANT_USERS

# TODO: Remove after envConverter returns postgres
# catches possible pg dialect values from envConverter
POSTGRES_DIALECT_OPTIONS = ['postgres', 'postgresql', 'pg']


def GetDBConnection(database_code: str, user_type: str):
conn_details = extract_db_credentials(database_code)

if conn_details == {}:
raise Exception("No DB Credentials found!")
database_name = conn_details["databaseName"]
if conn_details["dialect"] == "hana":
dialect_driver = "hana+hdbcli"
encrypt = conn_details["encrypt"]
validateCertificate = conn_details["validateCertificate"]
db = database_name + \
f"?encrypt={encrypt}?validateCertificate={validateCertificate}"
elif conn_details['dialect'] in POSTGRES_DIALECT_OPTIONS:
dialect_driver = "postgresql+psycopg2"
db = database_name

match user_type:
case HANA_TENANT_USERS.READ_USER:
databaseUser = conn_details["readUser"]
databasePassword = conn_details["readPassword"]
case HANA_TENANT_USERS.ADMIN_USER:
databaseUser = conn_details["adminUser"]
databasePassword = conn_details["adminPassword"]
case PG_TENANT_USERS.ADMIN_USER:
databaseUser = conn_details["adminUser"]
databasePassword = conn_details["adminPassword"]
case PG_TENANT_USERS.READ_USER:
databaseUser = conn_details["readUser"]
databasePassword = conn_details["readPassword"]

host = conn_details["host"]
port = conn_details["port"]
conn_string = _CreateConnectionString(
dialect_driver, databaseUser, databasePassword, host, port, db)
engine = create_engine(conn_string)
return engine
try:
conn_details = extract_db_credentials(database_code)
except Exception as e:
raise e
else:
database_name = conn_details["databaseName"]
if conn_details["dialect"] == "hana":
dialect_driver = "hana+hdbcli"
encrypt = conn_details["encrypt"]
validateCertificate = conn_details["validateCertificate"]
db = database_name + \
f"?encrypt={encrypt}?validateCertificate={validateCertificate}"
elif conn_details['dialect'] in POSTGRES_DIALECT_OPTIONS:
dialect_driver = "postgresql+psycopg2"
db = database_name
match user_type:
case HANA_TENANT_USERS.READ_USER:
databaseUser = conn_details["readUser"]
databasePassword = conn_details["readPassword"]
case HANA_TENANT_USERS.ADMIN_USER:
databaseUser = conn_details["adminUser"]
databasePassword = conn_details["adminPassword"]
case PG_TENANT_USERS.ADMIN_USER:
databaseUser = conn_details["adminUser"]
databasePassword = conn_details["adminPassword"]
case PG_TENANT_USERS.READ_USER:
databaseUser = conn_details["readUser"]
databasePassword = conn_details["readPassword"]

host = conn_details["host"]
port = conn_details["port"]
conn_string = _CreateConnectionString(
dialect_driver, databaseUser, databasePassword, host, port, db)
engine = create_engine(conn_string)
return engine


def GetConfigDBConnection(): # Single pre-configured postgres database
Expand Down Expand Up @@ -123,11 +123,19 @@ def db_svc_dialect_mapper(dialect: str) -> str:

def extract_db_credentials(database_code: str):
dbs = json.loads(os.environ["DATABASE_CREDENTIALS"])
for _db in dbs:
database_credential = DatabaseCredentials(_db)
if "alp-dataflow-gen" in database_credential.tags:
if database_credential.get_database_code() == database_code:
return database_credential.get_values()

if dbs == []:
raise Exception(
f"Database credentials environment variable is missing")
else:
_db = next(filter(lambda x: x["values"]
["code"] == database_code and "alp-dataflow-gen" in x["tags"], dbs), None)
if not _db:
raise Exception(
f"Database code {database_code} not found in database credentials")
else:
database_credential = DatabaseCredentials(_db)
return database_credential.get_values()


class DatabaseCredentials:
Expand All @@ -151,12 +159,17 @@ def get_validate_certificate(self) -> bool:
else:
return False

def get_values(self) -> DBCredentials:
def get_values(self) -> DBCredentialsType:
values = self.values["credentials"]
values["databaseName"] = self.values["databaseName"]
values["dialect"] = self.values["dialect"]
values["host"] = self.values["host"]
values["port"] = self.values["port"]
values["encrypt"] = self.get_encrypt()
values["validateCertificate"] = self.get_validate_certificate()
try:
# validate
DBCredentialsType(**values)
except Exception as e:
raise Exception(f"Failed validating database credentials values: {e}")
return values
4 changes: 1 addition & 3 deletions services/alp-dataflow-gen/pysrc/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from typing import Optional, List, Dict


class DBCredentials(BaseModel):
class DBCredentialsType(BaseModel):
adminPassword: str
adminUser: str
adminPasswordSalt: str
readPassword: str
readUser: str
readPasswordSalt: str
dialect: str
databaseName: str
host: str
Expand Down

0 comments on commit 611ae03

Please sign in to comment.