diff --git a/reframe/frontend/cli.py b/reframe/frontend/cli.py index 6fa179961..43a0a53e1 100644 --- a/reframe/frontend/cli.py +++ b/reframe/frontend/cli.py @@ -1299,10 +1299,18 @@ def print_infoline(param, value): print_infoline('output directory', repr(session_info['prefix_output'])) print_infoline('log files', ', '.join(repr(s) for s in session_info['log_files'])) + backend = rt.get_option('storage/0/backend') + if backend == 'sqlite': + dbfile = osext.expandvars(rt.get_option('storage/0/sqlite_db_file')) + dbinfo = f'sqlite file = {dbfile!r}' + elif backend == 'postgresql': + host = rt.get_option('storage/0/postgresql_host') + port = rt.get_option('storage/0/postgresql_port') + db = rt.get_option('storage/0/postgresql_db') + dbinfo = f'postgresql://{host}:{port}/{db}' print_infoline( 'results database', - f'[{storage_status}] ' - f'{osext.expandvars(rt.get_option("storage/0/sqlite_db_file"))!r}' + f'[{storage_status}] {dbinfo}' ) printer.info('') try: diff --git a/reframe/frontend/reporting/storage.py b/reframe/frontend/reporting/storage.py index 717574469..1f97a18b5 100644 --- a/reframe/frontend/reporting/storage.py +++ b/reframe/frontend/reporting/storage.py @@ -9,7 +9,22 @@ import json import os import re -import sqlite3 + +from sqlalchemy import (and_, + Column, + create_engine, + delete, + event, + Float, + ForeignKey, + Index, + MetaData, + select, + Table, + Text) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.engine.url import URL +from sqlalchemy.sql.elements import ClauseElement import reframe.utility.jsonext as jsonext import reframe.utility.osext as osext @@ -20,6 +35,137 @@ from ..reporting.utility import QuerySelector +class _ConnectionStrategy: + '''Abstract helper class for building the URL and kwargs for a given SQL dialect''' + + def __init__(self): + self.url = self._build_connection_url() + self.engine = create_engine(self.url, **self._connection_kwargs) + + @abc.abstractmethod + def _build_connection_url(self): + '''Return a SQLAlchemy URL string for this dialect. + + Implementations must return a URL suitable for passing to + `sqlalchemy.create_engine()`. 
+ ''' + + @property + def _connection_kwargs(self): + '''Per‑dialect kwargs for `create_engine()`''' + return {} + + @property + def json_column_type(self): + '''Return the JSON column type to use for JSON payloads''' + return Text + + @contextlib.contextmanager + def db_read(self, *args, **kwargs): + '''Default read context yields a transactional connection''' + with self.engine.begin() as conn: + yield conn + + @contextlib.contextmanager + def db_write(self, *args, **kwargs): + '''Default write context yields a transactional connection''' + with self.engine.begin() as conn: + yield conn + + +class _SqliteConnector(_ConnectionStrategy): + def __init__(self): + self.__db_file = os.path.join( + osext.expandvars(runtime().get_option('storage/0/sqlite_db_file')) + ) + mode = runtime().get_option( + 'storage/0/sqlite_db_file_mode' + ) + if not isinstance(mode, int): + self.__db_file_mode = int(mode, base=8) + else: + self.__db_file_mode = mode + + self.__db_lock = osext.ReadWriteFileLock( + os.path.join(os.path.dirname(self.__db_file), '.db.lock'), + self.__db_file_mode + ) + + prefix = os.path.dirname(self.__db_file) + if not os.path.exists(self.__db_file): + # Create subdirs if needed + if prefix: + os.makedirs(prefix, exist_ok=True) + + open(self.__db_file, 'a').close() + # Update DB file mode + os.chmod(self.__db_file, self.__db_file_mode) + + super().__init__() + + # Enable foreign keys for delete action to have cascade effect + @event.listens_for(self.engine, 'connect') + def set_sqlite_pragma(dbapi_connection, connection_record): + # Keep ON DELETE CASCADE behavior consistent + cursor = dbapi_connection.cursor() + cursor.execute('PRAGMA foreign_keys=ON') + cursor.close() + + def _build_connection_url(self): + return URL.create( + drivername='sqlite', + database=self.__db_file + ).render_as_string() + + @property + def _connection_kwargs(self): + timeout = runtime().get_option('storage/0/sqlite_conn_timeout') + return {'connect_args': {'timeout': timeout}} + + @contextlib.contextmanager + def db_read(self, *args, **kwargs): + with self.__db_lock.read_lock(): + with self.engine.begin() as conn: + yield conn + + @contextlib.contextmanager + def db_write(self, *args, **kwargs): + with self.__db_lock.write_lock(): + with self.engine.begin() as conn: + yield conn + + +class _PostgresConnector(_ConnectionStrategy): + def __init__(self): + super().__init__() + + def _build_connection_url(self): + host = runtime().get_option('storage/0/postgresql_host') + port = runtime().get_option('storage/0/postgresql_port') + db = runtime().get_option('storage/0/postgresql_db') + driver = runtime().get_option('storage/0/postgresql_driver') + user = os.getenv('RFM_POSTGRES_USER') + password = os.getenv('RFM_POSTGRES_PASSWORD') + if not (driver and host and port and db and user and password): + raise ReframeError( + 'Postgres connection info must be set in config and env') + + return URL.create( + drivername=f'postgresql+{driver}', + username=user, password=password, + host=host, port=port, database=db + ).render_as_string(hide_password=False) + + @property + def _connection_kwargs(self): + timeout = runtime().get_option('storage/0/postgresql_conn_timeout') + return {'connect_args': {'connect_timeout': timeout}} + + @property + def json_column_type(self): + return JSONB + + class StorageBackend: '''Abstract class that represents the results backend storage''' @@ -27,7 +173,9 @@ class StorageBackend: def create(cls, backend, *args, **kwargs): '''Factory method for creating storage backends''' if backend == 
'sqlite': - return _SqliteStorage(*args, **kwargs) + return _SqlStorage(_SqliteConnector(), *args, **kwargs) + elif backend == 'postgresql': + return _SqlStorage(_PostgresConnector(), *args, **kwargs) else: raise ReframeError(f'no such storage backend: {backend}') @@ -74,38 +222,39 @@ def remove_sessions(self, selector: QuerySelector): ''' -class _SqliteStorage(StorageBackend): +class _SqlStorage(StorageBackend): SCHEMA_VERSION = '1.0' - def __init__(self): - self.__db_file = os.path.join( - osext.expandvars(runtime().get_option('storage/0/sqlite_db_file')) - ) - mode = runtime().get_option( - 'storage/0/sqlite_db_file_mode' - ) - if not isinstance(mode, int): - self.__db_file_mode = int(mode, base=8) - else: - self.__db_file_mode = mode - - self.__db_lock = osext.ReadWriteFileLock( - os.path.join(os.path.dirname(self.__db_file), '.db.lock'), - self.__db_file_mode - ) - - def _db_file(self): - prefix = os.path.dirname(self.__db_file) - if not os.path.exists(self.__db_file): - # Create subdirs if needed - if prefix: - os.makedirs(prefix, exist_ok=True) - - self._db_create() - - self._db_create_indexes() + def __init__(self, connector: _ConnectionStrategy): + self.__connector = connector + # Container for core table objects + self.__metadata = MetaData() + self._db_schema() + self._db_create() self._db_schema_check() - return self.__db_file + + def _db_schema(self): + self.__sessions_table = Table('sessions', self.__metadata, + Column('uuid', Text, primary_key=True), + Column('session_start_unix', Float), + Column('session_end_unix', Float), + Column( + 'json_blob', self.__connector.json_column_type), + Column('report_file', Text), + Index('index_sessions_time', 'session_start_unix')) + self.__testcases_table = Table('testcases', self.__metadata, + Column('name', Text), + Column('system', Text), + Column('partition', Text), + Column('environ', Text), + Column( + 'job_completion_time_unix', Float), + Column('session_uuid', Text, ForeignKey( + 'sessions.uuid', ondelete='CASCADE')), + Column('uuid', Text), + Index('index_testcases_time', 'job_completion_time_unix')) + self.__metadata_table = Table('metadata', self.__metadata, + Column('schema_version', Text)) def _db_matches(self, patt, item): if patt is None: @@ -124,76 +273,32 @@ def _db_filter_json(self, expr, item): return eval(expr, None, item) - def _db_connect(self, *args, **kwargs): - timeout = runtime().get_option('storage/0/sqlite_conn_timeout') - kwargs.setdefault('timeout', timeout) - with getprofiler().time_region('sqlite connect'): - return sqlite3.connect(*args, **kwargs) - - @contextlib.contextmanager - def _db_read(self, *args, **kwargs): - with self.__db_lock.read_lock(): - with self._db_connect(*args, **kwargs) as conn: - yield conn - - @contextlib.contextmanager - def _db_write(self, *args, **kwargs): - with self.__db_lock.write_lock(): - with self._db_connect(*args, **kwargs) as conn: - yield conn - def _db_create(self): clsname = type(self).__name__ getlogger().debug( - f'{clsname}: creating results database in {self.__db_file}...' + f'{clsname}: creating results database in {self.__connector.engine.url.database}...' 
) - with self._db_write(self.__db_file) as conn: - conn.execute('CREATE TABLE IF NOT EXISTS sessions(' - 'uuid TEXT PRIMARY KEY, ' - 'session_start_unix REAL, ' - 'session_end_unix REAL, ' - 'json_blob TEXT, ' - 'report_file TEXT)') - conn.execute('CREATE TABLE IF NOT EXISTS testcases(' - 'name TEXT,' - 'system TEXT, ' - 'partition TEXT, ' - 'environ TEXT, ' - 'job_completion_time_unix REAL, ' - 'session_uuid TEXT, ' - 'uuid TEXT, ' - 'FOREIGN KEY(session_uuid) ' - 'REFERENCES sessions(uuid) ON DELETE CASCADE)') - - # Update DB file mode - os.chmod(self.__db_file, self.__db_file_mode) - - def _db_create_indexes(self): - clsname = type(self).__name__ - getlogger().debug(f'{clsname}: creating database indexes if needed') - with self._db_connect(self.__db_file) as conn: - conn.execute('CREATE INDEX IF NOT EXISTS index_testcases_time ' - 'on testcases(job_completion_time_unix)') - conn.execute('CREATE TABLE IF NOT EXISTS metadata(' - 'schema_version TEXT)') - conn.execute('CREATE INDEX IF NOT EXISTS index_sessions_time ' - 'on sessions(session_start_unix)') + self.__metadata.create_all(self.__connector.engine) def _db_schema_check(self): - with self._db_read(self.__db_file) as conn: + with self.__connector.db_read() as conn: results = conn.execute( - 'SELECT schema_version FROM metadata').fetchall() + self.__metadata_table.select() + ).fetchall() if not results: # DB is new, insert the schema version - with self._db_write(self.__db_file) as conn: - conn.execute('INSERT INTO metadata VALUES(:schema_version)', - {'schema_version': self.SCHEMA_VERSION}) + with self.__connector.db_write() as conn: + conn.execute( + self.__metadata_table.insert().values( + schema_version=self.SCHEMA_VERSION + ) + ) else: found_ver = results[0][0] if found_ver != self.SCHEMA_VERSION: raise ReframeError( - f'results DB in {self.__db_file!r} is ' + f'results DB in {self.__connector.engine.url.database!r} is ' 'of incompatible version: ' f'found {found_ver}, required: {self.SCHEMA_VERSION}' ) @@ -202,44 +307,45 @@ def _db_store_report(self, conn, report, report_file_path): session_start_unix = report['session_info']['time_start_unix'] session_end_unix = report['session_info']['time_end_unix'] session_uuid = report['session_info']['uuid'] + # Pass dict directly for JSONB + report_json = jsonext.dumps(report) + # Choose payload shape per backend + if self.__connector.json_column_type is JSONB: + json_payload = json.loads(report_json) + else: + json_payload = report_json + conn.execute( - 'INSERT INTO sessions VALUES(' - ':uuid, :session_start_unix, :session_end_unix, ' - ':json_blob, :report_file)', - { - 'uuid': session_uuid, - 'session_start_unix': session_start_unix, - 'session_end_unix': session_end_unix, - 'json_blob': jsonext.dumps(report), - 'report_file': report_file_path - } + self.__sessions_table.insert().values( + uuid=session_uuid, + session_start_unix=session_start_unix, + session_end_unix=session_end_unix, + json_blob=json_payload, + report_file=report_file_path + ) ) for run in report['runs']: for testcase in run['testcases']: sys, part = testcase['system'], testcase['partition'] conn.execute( - 'INSERT INTO testcases VALUES(' - ':name, :system, :partition, :environ, ' - ':job_completion_time_unix, ' - ':session_uuid, :uuid)', - { - 'name': testcase['name'], - 'system': sys, - 'partition': part, - 'environ': testcase['environ'], - 'job_completion_time_unix': testcase[ + self.__testcases_table.insert().values( + name=testcase['name'], + system=sys, + partition=part, + environ=testcase['environ'], + 
job_completion_time_unix=testcase[ 'job_completion_time_unix' ], - 'session_uuid': session_uuid, - 'uuid': testcase['uuid'] - } + session_uuid=session_uuid, + uuid=testcase['uuid'] + ) ) return session_uuid @time_function def store(self, report, report_file=None): - with self._db_write(self._db_file()) as conn: + with self.__connector.db_write() as conn: return self._db_store_report(conn, report, report_file) @time_function @@ -265,6 +371,9 @@ def _extract_sess_info(s): session_infos = {} sessions = {} for uuid, json_blob in results: + if not isinstance(json_blob, str): + # serialize into a json string + json_blob = json.dumps(json_blob) sessions.setdefault(uuid, json_blob) session_infos.setdefault(uuid, _extract_sess_info(json_blob)) @@ -289,17 +398,26 @@ def _decode_and_index_sessions(self, json_blobs): for sess in self._mass_json_decode(*json_blobs)} @time_function - def _fetch_testcases_raw(self, condition): + def _fetch_testcases_raw(self, condition: ClauseElement, order_by: ClauseElement = None): # Retrieve relevant session info and index it in Python - getprofiler().enter_region('sqlite session query') - with self._db_read(self._db_file()) as conn: - query = ('SELECT uuid, json_blob FROM sessions WHERE uuid IN ' - '(SELECT DISTINCT session_uuid FROM testcases ' - f'WHERE {condition})') + getprofiler().enter_region( + f'{self.__connector.engine.url.drivername} session query') + with self.__connector.db_read() as conn: + query = ( + select( + self.__sessions_table.c.uuid, + self.__sessions_table.c.json_blob + ) + .where( + self.__sessions_table.c.uuid.in_( + select(self.__testcases_table.c.session_uuid) + .distinct() + .where(condition) + ) + ) + ) getlogger().debug(query) - # Create SQLite function for filtering using name patterns - conn.create_function('REGEXP', 2, self._db_matches) results = conn.execute(query).fetchall() getprofiler().exit_region() @@ -310,11 +428,12 @@ def _fetch_testcases_raw(self, condition): ) # Extract the test case data by extracting their UUIDs - getprofiler().enter_region('sqlite testcase query') - with self._db_read(self._db_file()) as conn: - query = f'SELECT uuid FROM testcases WHERE {condition}' + getprofiler().enter_region( + f'{self.__connector.engine.url.drivername} testcase query') + with self.__connector.db_read() as conn: + query = select(self.__testcases_table.c.uuid).where( + condition).order_by(order_by) getlogger().debug(query) - conn.create_function('REGEXP', 2, self._db_matches) results = conn.execute(query).fetchall() getprofiler().exit_region() @@ -339,16 +458,24 @@ def _fetch_testcases_raw(self, condition): @time_function def _fetch_testcases_from_session(self, selector, name_patt=None, test_filter=None): - query = 'SELECT uuid, json_blob from sessions' + query = select( + self.__sessions_table.c.uuid, + self.__sessions_table.c.json_blob + ) if selector.by_session_uuid(): - query += f' WHERE uuid == "{selector.uuid}"' + query = query.where( + self.__sessions_table.c.uuid == selector.uuid + ) elif selector.by_time_period(): ts_start, ts_end = selector.time_period - query += (f' WHERE (session_start_unix >= {ts_start} AND ' - f'session_start_unix < {ts_end})') + query = query.where( + self.__sessions_table.c.session_start_unix >= ts_start, + self.__sessions_table.c.session_start_unix < ts_end + ) - getprofiler().enter_region('sqlite session query') - with self._db_read(self._db_file()) as conn: + getprofiler().enter_region( + f'{self.__connector.engine.url.drivername} session query') + with self.__connector.db_read() as conn: 
getlogger().debug(query) results = conn.execute(query).fetchall() @@ -370,13 +497,16 @@ def _fetch_testcases_from_session(self, selector, name_patt=None, @time_function def _fetch_testcases_time_period(self, ts_start, ts_end, name_patt=None, test_filter=None): - expr = (f'job_completion_time_unix >= {ts_start} AND ' - f'job_completion_time_unix < {ts_end}') + expr = [ + self.__testcases_table.c.job_completion_time_unix >= ts_start, + self.__testcases_table.c.job_completion_time_unix < ts_end + ] if name_patt: - expr += f' AND name REGEXP "{name_patt}"' + expr.append(self.__testcases_table.c.name.regexp_match(name_patt)) testcases = self._fetch_testcases_raw( - f'({expr}) ORDER BY job_completion_time_unix' + and_(*expr), + self.__testcases_table.c.job_completion_time_unix ) filt_fn = functools.partial(self._db_filter_json, test_filter) return [*filter(filt_fn, testcases)] @@ -395,16 +525,24 @@ def fetch_testcases(self, selector: QuerySelector, @time_function def fetch_sessions(self, selector: QuerySelector, decode=True): - query = 'SELECT uuid, json_blob FROM sessions' + query = select( + self.__sessions_table.c.uuid, + self.__sessions_table.c.json_blob + ) if selector.by_time_period(): ts_start, ts_end = selector.time_period - query += (f' WHERE (session_start_unix >= {ts_start} AND ' - f'session_start_unix < {ts_end})') + query = query.where( + self.__sessions_table.c.session_start_unix >= ts_start, + self.__sessions_table.c.session_start_unix < ts_end + ) elif selector.by_session_uuid(): - query += f' WHERE uuid == "{selector.uuid}"' + query = query.where( + self.__sessions_table.c.uuid == selector.uuid + ) - getprofiler().enter_region('sqlite session query') - with self._db_read(self._db_file()) as conn: + getprofiler().enter_region( + f'{self.__connector.engine.url.drivername} session query') + with self.__connector.db_read() as conn: getlogger().debug(query) results = conn.execute(query).fetchall() @@ -421,15 +559,18 @@ def fetch_sessions(self, selector: QuerySelector, decode=True): def _do_remove(self, conn, uuids): '''Remove sessions''' - # Enable foreign keys for delete action to have cascade effect - conn.execute('PRAGMA foreign_keys = ON') - uuids_sql = ','.join(f'"{uuid}"' for uuid in uuids) - query = f'DELETE FROM sessions WHERE uuid IN ({uuids_sql})' + query = ( + delete(self.__sessions_table) + .where(self.__sessions_table.c.uuid.in_(uuids)) + ) getlogger().debug(query) conn.execute(query).fetchall() # Retrieve the uuids that have been removed - query = f'SELECT uuid FROM sessions WHERE uuid IN ({uuids_sql})' + query = ( + select(self.__sessions_table.c.uuid) + .where(self.__sessions_table.c.uuid.in_(uuids)) + ) getlogger().debug(query) results = conn.execute(query).fetchall() not_removed = {rec[0] for rec in results} @@ -438,11 +579,11 @@ def _do_remove(self, conn, uuids): def _do_remove2(self, conn, uuids): '''Remove sessions using the RETURNING keyword''' - # Enable foreign keys for delete action to have cascade effect - conn.execute('PRAGMA foreign_keys = ON') - uuids_sql = ','.join(f'"{uuid}"' for uuid in uuids) - query = (f'DELETE FROM sessions WHERE uuid IN ({uuids_sql}) ' - 'RETURNING uuid') + query = ( + delete(self.__sessions_table) + .where(self.__sessions_table.c.uuid.in_(uuids)) + .returning(self.__sessions_table.c.uuid) + ) getlogger().debug(query) results = conn.execute(query).fetchall() return [rec[0] for rec in results] @@ -455,8 +596,8 @@ def remove_sessions(self, selector: QuerySelector): uuids = [sess['session_info']['uuid'] for sess in 
self.fetch_sessions(selector)] - with self._db_write(self._db_file()) as conn: - if sqlite3.sqlite_version_info >= (3, 35, 0): + with self.__connector.db_write() as conn: + if getattr(conn.dialect, 'delete_returning', False): return self._do_remove2(conn, uuids) else: return self._do_remove(conn, uuids) diff --git a/reframe/schemas/config.json b/reframe/schemas/config.json index c731db1bd..4cbeaa79e 100644 --- a/reframe/schemas/config.json +++ b/reframe/schemas/config.json @@ -331,7 +331,7 @@ "prepare_cmds": { "type": "array", "items": {"type": "string"} - }, + }, "processor": {"$ref": "#/defs/processor_info"}, "devices": {"$ref": "#/defs/devices"}, "features": { @@ -560,6 +560,11 @@ "sqlite_conn_timeout": {"type": "number"}, "sqlite_db_file": {"type": "string"}, "sqlite_db_file_mode": {"type": "string"}, + "postgresql_driver": {"type": "string"}, + "postgresql_host": {"type": "string"}, + "postgresql_port": {"type": "number"}, + "postgresql_db": {"type": "string"}, + "postgresql_conn_timeout": {"type": "number"}, "target_systems": {"$ref": "#/defs/system_ref"} } } @@ -654,6 +659,8 @@ "storage/sqlite_conn_timeout": 60, "storage/sqlite_db_file": "${HOME}/.reframe/reports/results.db", "storage/sqlite_db_file_mode": "644", + "storage/postgresql_conn_timeout": 60, + "storage/postgresql_driver": "psycopg2", "storage/target_systems": ["*"], "systems/descr": "", "systems/max_local_jobs": 8, diff --git a/requirements.txt b/requirements.txt index 192cdad95..ce634bba1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,6 @@ setuptools==80.9.0; python_version >= '3.9' tabulate==0.8.10; python_version == '3.6' tabulate==0.9.0; python_version >= '3.7' wcwidth==0.2.14 +sqlalchemy==2.0.41 +psycopg2-binary==2.9.8 #+pygelf%pygelf==0.4.0
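
Note (outside the diff): a minimal configuration sketch showing how the PostgreSQL backend wired up above might be enabled. Only the option names are taken from this change set (the 'storage/0/backend' option read in cli.py, the postgresql_* keys added to config.json, and the RFM_POSTGRES_USER / RFM_POSTGRES_PASSWORD environment variables read by _PostgresConnector); the host, port and database values are placeholders, and the surrounding settings-file layout is assumed to follow ReFrame's usual site_configuration structure.

    # settings.py (hypothetical excerpt)
    site_configuration = {
        # ... 'systems', 'environments', etc. ...
        'storage': [
            {
                # Select the PostgreSQL backend; the factory in storage.py
                # also accepts 'sqlite'.
                'backend': 'postgresql',
                'postgresql_driver': 'psycopg2',      # schema default
                'postgresql_host': 'db.example.org',  # placeholder
                'postgresql_port': 5432,              # placeholder
                'postgresql_db': 'reframe_results',   # placeholder
                'postgresql_conn_timeout': 60         # schema default (seconds)
            }
        ]
    }

    # Credentials are read from the environment by _PostgresConnector,
    # never from the configuration file:
    #   export RFM_POSTGRES_USER=<user>
    #   export RFM_POSTGRES_PASSWORD=<password>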