diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..b6b309e3 --- /dev/null +++ b/.flake8 @@ -0,0 +1,14 @@ +# This file is part of the EESSI filesystem layer, +# see https://github.com/EESSI/filesystem-layer +# +# author: Thomas Roeblitz (@trz42) +# +# license: GPLv2 +# + +[flake8] +max-line-length = 120 + +# ignore "Black would make changes" produced by flake8-black +# see also https://github.com/houndci/hound/issues/1769 +extend-ignore = BLK100 diff --git a/.github/workflows/check-flake8.yml b/.github/workflows/check-flake8.yml new file mode 100644 index 00000000..f0ebe250 --- /dev/null +++ b/.github/workflows/check-flake8.yml @@ -0,0 +1,36 @@ +# This file is part of the EESSI filesystem layer, +# see https://github.com/EESSI/filesystem-layer +# +# author: Thomas Roeblitz (@trz42) +# +# license: GPLv2 +# + +name: Run tests +on: [push, pull_request] +# Declare default permissions as read only. +permissions: read-all +jobs: + test: + runs-on: ubuntu-22.04 + strategy: + matrix: + python: [3.7, 3.8, 3.9, '3.10', '3.11', '3.12'] + fail-fast: false + steps: + - name: checkout + uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0 + + - name: set up Python + uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # v4.3.0 + with: + python-version: ${{matrix.python}} + + - name: Install required Python packages + pytest + flake8 + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade flake8 + + - name: Run flake8 to verify PEP8-compliance of Python code + run: | + flake8 diff --git a/scripts/automated_ingestion/automated_ingestion.cfg.example b/scripts/automated_ingestion/automated_ingestion.cfg.example index 68df3e4e..18009d88 100644 --- a/scripts/automated_ingestion/automated_ingestion.cfg.example +++ b/scripts/automated_ingestion/automated_ingestion.cfg.example @@ -63,13 +63,61 @@ pr_body = A new tarball has been staged for {pr_url}. ``` + +
+    <details>
+    <summary>Overview of tarball contents</summary>
+
+    {tar_overview}
+
+    </details>
+
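The `pr_body` value above is a template whose named placeholders (`{pr_url}`, `{cvmfs_repo}`, `{metadata}`, `{tar_overview}`) are filled in by the ingestion script before a staging pull request is opened. A minimal sketch of that substitution, assuming the template sits in the `[github]` section alongside the other `*_pr_body` options and is rendered with `str.format`; all values below are hypothetical:

```python
import configparser

config = configparser.ConfigParser()
config.read("automated_ingestion.cfg")

# Hypothetical inputs; the real script derives these from the tarball's
# metadata file and the target CernVM-FS repository.
pr_body = config["github"]["pr_body"].format(
    pr_url="https://github.com/EESSI/software-layer/pull/123",
    cvmfs_repo="software.eessi.io",
    metadata='{"link2pr": {"repo": "EESSI/software-layer", "pr": "123"}}',
    tar_overview="eessi-2023.06-software-linux-x86_64-generic.tar.gz",
)
print(pr_body)
```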
+# Method for creating staging PRs: +# - 'individual': create one PR per tarball (old method) +# - 'grouped': group tarballs by link2pr and create one PR per group (new method) +staging_pr_method = individual + +# Template for individual tarball PRs +individual_pr_body = A new tarball has been staged for {pr_url}. + Please review the contents of this tarball carefully. + Merging this PR will lead to automatic ingestion of the tarball to the repository {cvmfs_repo}. + +
+    <details>
+    <summary>Metadata of tarball</summary>
+
+    ```
+    {metadata}
+    ```
+
+    </details>
+
+    <details>
+    <summary>Overview of tarball contents</summary>
+
+    {tar_overview}
+
+    </details>
+
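The `staging_pr_method` option introduced above selects between the two code paths added to `automated_ingestion.py` further down in this diff. A short sketch of how the option is read and validated, mirroring `parse_config()` and the `config['github'].get('staging_pr_method', 'individual')` lookup used in `main()`:

```python
import configparser
import sys

config = configparser.ConfigParser()
config.read("automated_ingestion.cfg")

# Falls back to the old per-tarball behaviour when the option is absent.
staging_pr_method = config["github"].get("staging_pr_method", "individual")

if staging_pr_method not in ("individual", "grouped"):
    sys.exit(f"Invalid staging_pr_method: {staging_pr_method!r}; "
             "must be either 'individual' or 'grouped'")

print(f"Using staging PR method: {staging_pr_method}")
```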
+ +# Template for grouped tarball PRs +grouped_pr_body = A group of tarballs has been staged for {pr_url}. + Please review the contents of these tarballs carefully. + Merging this PR will lead to automatic ingestion of the approved tarballs to the repository {cvmfs_repo}. + Unchecked tarballs will be marked as rejected. + + {tarballs} +
+    <details>
+    <summary>Overview of tarball contents</summary>
+
+    {tar_overview}
+
+    </details>
+ + {metadata} + +# Template for payload overview +task_summary_payload_template = + {payload_overview} + [slack] ingestion_notification = yes diff --git a/scripts/automated_ingestion/automated_ingestion.py b/scripts/automated_ingestion/automated_ingestion.py index 92dac552..974f8497 100755 --- a/scripts/automated_ingestion/automated_ingestion.py +++ b/scripts/automated_ingestion/automated_ingestion.py @@ -1,12 +1,15 @@ #!/usr/bin/env python3 -from eessitarball import EessiTarball -from pid.decorator import pidfile +from eessitarball import EessiTarball, EessiTarballGroup +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_task import EESSITask, TaskState +from eessi_task_description import EESSITaskDescription +from s3_bucket import EESSIS3Bucket +from pid.decorator import pidfile # noqa: F401 from pid import PidFileError +from utils import log_function_entry_exit, log_message, LoggingScope, set_logging_scopes import argparse -import boto3 -import botocore import configparser import github import json @@ -14,6 +17,8 @@ import os import pid import sys +from pathlib import Path +from typing import List REQUIRED_CONFIG = { 'secrets': ['aws_secret_access_key', 'aws_access_key_id', 'github_pat'], @@ -33,83 +38,350 @@ def error(msg, code=1): """Print an error and exit.""" - logging.error(msg) + log_message(LoggingScope.ERROR, 'ERROR', msg) sys.exit(code) -def find_tarballs(s3, bucket, extension='.tar.gz', metadata_extension='.meta.txt'): - """Return a list of all tarballs in an S3 bucket that have a metadata file with the given extension (and same filename).""" +def find_tarballs(s3_bucket, extension='.tar.gz', metadata_extension='.meta.txt'): + """ + Return a list of all tarballs in an S3 bucket that have a metadata file with + the given extension (and same filename). 
+ """ # TODO: list_objects_v2 only returns up to 1000 objects - s3_objects = s3.list_objects_v2(Bucket=bucket).get('Contents', []) + s3_objects = s3_bucket.list_objects_v2().get('Contents', []) files = [obj['Key'] for obj in s3_objects] tarballs = [ file for file in files - if file.endswith(extension) - and file + metadata_extension in files + if file.endswith(extension) and file + metadata_extension in files ] return tarballs +@log_function_entry_exit() +def find_tarball_groups(s3_bucket, config, extension='.tar.gz', metadata_extension='.meta.txt'): + """Return a dictionary of tarball groups, keyed by (repo, pr_number).""" + tarballs = find_tarballs(s3_bucket, extension, metadata_extension) + groups = {} + + for tarball in tarballs: + # Download metadata to get link2pr info + metadata_file = tarball + metadata_extension + local_metadata = os.path.join(config['paths']['download_dir'], os.path.basename(metadata_file)) + + try: + s3_bucket.download_file(metadata_file, local_metadata) + with open(local_metadata, 'r') as meta: + metadata = json.load(meta) + repo = metadata['link2pr']['repo'] + pr = metadata['link2pr']['pr'] + group_key = (repo, pr) + + if group_key not in groups: + groups[group_key] = [] + groups[group_key].append(tarball) + except Exception as err: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to process metadata for %s: %s", tarball, err) + continue + finally: + # Clean up downloaded metadata file + if os.path.exists(local_metadata): + os.remove(local_metadata) + + return groups + + +@log_function_entry_exit() def parse_config(path): """Parse the configuration file.""" config = configparser.ConfigParser() try: config.read(path) - except: - error(f'Unable to read configuration file {path}!') + except Exception as err: + error(f'Unable to read configuration file {path}!\nException: {err}') # Check if all required configuration parameters/sections can be found. for section in REQUIRED_CONFIG.keys(): - if not section in config: + if section not in config: error(f'Missing section "{section}" in configuration file {path}.') for item in REQUIRED_CONFIG[section]: - if not item in config[section]: + if item not in config[section]: error(f'Missing configuration item "{item}" in section "{section}" of configuration file {path}.') + + # Validate staging_pr_method + staging_method = config['github'].get('staging_pr_method', 'individual') + if staging_method not in ['individual', 'grouped']: + error( + f'Invalid staging_pr_method: "{staging_method}" in configuration file {path}. ' + 'Must be either "individual" or "grouped".' 
+ ) + + # Validate PR body templates + if staging_method == 'individual' and 'individual_pr_body' not in config['github']: + error(f'Missing "individual_pr_body" in configuration file {path}.') + if staging_method == 'grouped' and 'grouped_pr_body' not in config['github']: + error(f'Missing "grouped_pr_body" in configuration file {path}.') + return config +@log_function_entry_exit() def parse_args(): """Parse the command-line arguments.""" parser = argparse.ArgumentParser() + + # Logging options + logging_group = parser.add_argument_group('Logging options') + logging_group.add_argument('--log-file', + help='Path to log file (overrides config file setting)') + logging_group.add_argument('--console-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Logging level for console output (overrides config file setting)') + logging_group.add_argument('--file-level', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + help='Logging level for file output (overrides config file setting)') + logging_group.add_argument('--quiet', + action='store_true', + help='Suppress console output (overrides all other console settings)') + logging_group.add_argument('--log-scopes', + help='Comma-separated list of logging scopes using +/- syntax. ' + 'Examples: "+FUNC_ENTRY_EXIT" (enable only function entry/exit), ' + '"+ALL,-FUNC_ENTRY_EXIT" (enable all except function entry/exit), ' + '"+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE" (enable function entry/exit but disable example)') + + # Existing arguments parser.add_argument('-c', '--config', type=str, help='path to configuration file', default='automated_ingestion.cfg', dest='config') parser.add_argument('-d', '--debug', help='enable debug mode', action='store_true', dest='debug') - parser.add_argument('-l', '--list', help='only list available tarballs', action='store_true', dest='list_only') - args = parser.parse_args() - return args + parser.add_argument('-l', '--list', help='only list available tarballs or tasks', action='store_true', + dest='list_only') + parser.add_argument('--task-based', help='use task-based ingestion instead of tarball-based. ' + 'Optionally specify comma-separated list of extensions (default: .task)', + nargs='?', const='.task', default=False) + + return parser.parse_args() + + +@log_function_entry_exit() +def setup_logging(config, args): + """ + Configure logging based on configuration file and command line arguments. + Command line arguments take precedence over config file settings. 
+ + Args: + config: Configuration dictionary + args: Parsed command line arguments + """ + # Get settings from config file + log_file = config['logging'].get('filename') + config_console_level = LOG_LEVELS.get(config['logging'].get('level', 'INFO').upper(), logging.INFO) + config_file_level = LOG_LEVELS.get(config['logging'].get('file_level', 'DEBUG').upper(), logging.DEBUG) + + # Override with command line arguments if provided + log_file = args.log_file if args.log_file else log_file + console_level = getattr(logging, args.console_level) if args.console_level else config_console_level + file_level = getattr(logging, args.file_level) if args.file_level else config_file_level + + # Debug mode overrides console level + if args.debug: + console_level = logging.DEBUG + + # Set up logging scopes + if args.log_scopes: + set_logging_scopes(args.log_scopes) + log_message(LoggingScope.DEBUG, 'DEBUG', "Enabled logging scopes: %s", args.log_scopes) + + # Create logger + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) # Set root logger to lowest level + + # Create formatters + console_formatter = logging.Formatter('%(levelname)-8s: %(message)s') + file_formatter = logging.Formatter('%(asctime)s - %(levelname)-8s: %(message)s') + + # Console handler (only if not quiet) + if not args.quiet: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(console_level) + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + + # File handler (if log file is specified) + if log_file: + # Ensure log directory exists + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(file_level) + file_handler.setFormatter(file_formatter) + logger.addHandler(file_handler) + + return logger @pid.decorator.pidfile('automated_ingestion.pid') +@log_function_entry_exit() def main(): """Main function.""" args = parse_args() config = parse_config(args.config) - log_file = config['logging'].get('filename', None) - log_format = config['logging'].get('format', '%(levelname)s:%(message)s') - log_level = LOG_LEVELS.get(config['logging'].get('level', 'INFO').upper(), logging.WARN) - log_level = logging.DEBUG if args.debug else log_level - logging.basicConfig(filename=log_file, format=log_format, level=log_level) + setup_logging(config, args) + # TODO: check configuration: secrets, paths, permissions on dirs, etc gh_pat = config['secrets']['github_pat'] gh_staging_repo = github.Github(gh_pat).get_repo(config['github']['staging_repo']) - s3 = boto3.client( - 's3', - aws_access_key_id=config['secrets']['aws_access_key_id'], - aws_secret_access_key=config['secrets']['aws_secret_access_key'], - ) buckets = json.loads(config['aws']['staging_buckets']) for bucket, cvmfs_repo in buckets.items(): - tarballs = find_tarballs(s3, bucket) - if args.list_only: - for num, tarball in enumerate(tarballs): - print(f'[{bucket}] {num}: {tarball}') + # Create our custom S3 bucket for this bucket + s3_bucket = EESSIS3Bucket(config, bucket) + + if args.task_based: + # Task-based listing + extensions = args.task_based.split(',') + tasks = find_deployment_tasks(s3_bucket, extensions) + if args.list_only: + log_message(LoggingScope.GROUP_OPS, 'INFO', "#tasks: %d", len(tasks)) + for num, task in enumerate(tasks): + log_message(LoggingScope.GROUP_OPS, 'INFO', "[%s] %d: %s", bucket, num, task) + else: + # Process each task file + for task_path in tasks: + log_message(LoggingScope.GROUP_OPS, 'INFO', 
"Processing task: %s", task_path) + try: + # Create EESSITask for the task file + try: + task = EESSITask( + EESSITaskDescription(EESSIDataAndSignatureObject(config, task_path, s3_bucket)), + config, cvmfs_repo, gh_staging_repo + ) + + except Exception as err: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to create EESSITask for task %s: %s", + task_path, str(err)) + continue + + log_message(LoggingScope.GROUP_OPS, 'INFO', "Task: %s", task) + + previous_state = None + current_state = task.determine_state() + log_message(LoggingScope.GROUP_OPS, 'INFO', "Task '%s' is in state '%s'", + task_path, current_state.name) + while (current_state is not None and + current_state != TaskState.DONE and + previous_state != current_state): + previous_state = current_state + log_message(LoggingScope.GROUP_OPS, 'INFO', + "Task '%s': BEFORE handle(): previous state = '%s', current state = '%s'", + task_path, previous_state.name, current_state.name) + current_state = task.handle() + log_message(LoggingScope.GROUP_OPS, 'INFO', + "Task '%s': AFTER handle(): previous state = '%s', current state = '%s'", + task_path, previous_state.name, current_state.name) + + # # TODO: update the information shown below (what makes sense to show?) + # # Log information about the task + # task_object = task.description.task_object + # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task file: %s", task_object.local_file_path) + # log_message(LoggingScope.GROUP_OPS, 'INFO', "Signature file: %s", task_object.local_sig_path) + # log_message(LoggingScope.GROUP_OPS, 'INFO', "Signature verified: %s", + # task.description.signature_verified) + + # # Log the ETags of the downloaded task file + # file_etag, sig_etag = task.description.task_object.get_etags() + # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task file %s has ETag: %s", task_path, file_etag) + # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task signature %s has ETag: %s", + # task.description.task_object.remote_sig_path, sig_etag) + + # # TODO: Process the task file contents + # # This would involve reading the task file, parsing its contents, + # # and performing the required actions based on the task type + # log_message(LoggingScope.GROUP_OPS, 'INFO', "TODO: Processing task file: %s", task_path) + # task.handle() + + except Exception as err: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to process task %s: %s", task_path, str(err)) + continue else: - for tarball in tarballs: - tar = EessiTarball(tarball, config, gh_staging_repo, s3, bucket, cvmfs_repo) - tar.run_handler() + # Original tarball-based processing + if config['github'].get('staging_pr_method', 'individual') == 'grouped': + # use new grouped PR method + tarball_groups = find_tarball_groups(s3_bucket, config) + if args.list_only: + log_message(LoggingScope.GROUP_OPS, 'INFO', "#tarball_groups: %d", len(tarball_groups)) + for (repo, pr_id), tarballs in tarball_groups.items(): + log_message(LoggingScope.GROUP_OPS, 'INFO', " %s#%s: #tarballs %d", repo, pr_id, len(tarballs)) + else: + for (repo, pr_id), tarballs in tarball_groups.items(): + if tarballs: + # Create a group for these tarballs + group = EessiTarballGroup(tarballs[0], config, gh_staging_repo, s3_bucket, cvmfs_repo) + log_message(LoggingScope.GROUP_OPS, 'INFO', "group created\n%s", + group.to_string(oneline=True)) + group.process_group(tarballs) + else: + # use old individual PR method + tarballs = find_tarballs(s3_bucket) + if args.list_only: + for num, tarball in enumerate(tarballs): + log_message(LoggingScope.GROUP_OPS, 'INFO', "[%s] %d: %s", 
bucket, num, tarball) + else: + for tarball in tarballs: + tar = EessiTarball(tarball, config, gh_staging_repo, s3_bucket, cvmfs_repo) + tar.run_handler() + + +@log_function_entry_exit() +def find_deployment_tasks(s3_bucket, extensions: List[str] = None) -> List[str]: + """ + Return a list of all task files in an S3 bucket with the given extensions, + but only if a corresponding payload file exists (same name without extension). + + Args: + s3_bucket: EESSIS3Bucket instance + extensions: List of file extensions to look for (default: ['.task']) + + Returns: + List of task filenames found in the bucket that have a corresponding payload + """ + if extensions is None: + extensions = ['.task'] + + files = [] + continuation_token = None + + while True: + # List objects with pagination + if continuation_token: + response = s3_bucket.list_objects_v2( + ContinuationToken=continuation_token + ) + else: + response = s3_bucket.list_objects_v2() + + # Add files from this page + files.extend([obj['Key'] for obj in response.get('Contents', [])]) + + # Check if there are more pages + if response.get('IsTruncated'): + continuation_token = response.get('NextContinuationToken') + else: + break + + # Create a set of all files for faster lookup + file_set = set(files) + + # Return only task files that have a corresponding payload + result = [] + for file in files: + for ext in extensions: + if file.endswith(ext) and file[:-len(ext)] in file_set: + result.append(file) + break # Found a matching extension, no need to check others + + return result if __name__ == '__main__': diff --git a/scripts/automated_ingestion/eessi_data_object.py b/scripts/automated_ingestion/eessi_data_object.py new file mode 100644 index 00000000..c7adc05b --- /dev/null +++ b/scripts/automated_ingestion/eessi_data_object.py @@ -0,0 +1,334 @@ +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import configparser + +from utils import log_function_entry_exit, log_message, LoggingScope +from remote_storage import RemoteStorageClient, DownloadMode + + +@dataclass +class EESSIDataAndSignatureObject: + """Class representing an EESSI data file and its signature in remote storage and locally.""" + + # Configuration + config: configparser.ConfigParser + + # Remote paths + remote_file_path: str # Path to data file in remote storage + remote_sig_path: str # Path to signature file in remote storage + + # Local paths + local_file_path: Path # Path to local data file + local_sig_path: Path # Path to local signature file + + # Remote storage client + remote_client: RemoteStorageClient + + @log_function_entry_exit() + def __init__(self, config: configparser.ConfigParser, remote_file_path: str, remote_client: RemoteStorageClient): + """ + Initialize an EESSI data and signature object handler. 
+ + Args: + config: Configuration object containing remote storage and local directory information + remote_file_path: Path to data file in remote storage + remote_client: Remote storage client implementing the RemoteStorageClient protocol + """ + self.config = config + self.remote_file_path = remote_file_path + sig_ext = config['signatures']['signature_file_extension'] + self.remote_sig_path = remote_file_path + sig_ext + + # Set up local paths + local_dir = Path(config['paths']['download_dir']) + # Use the full remote path structure, removing any leading slashes + remote_path = remote_file_path.lstrip('/') + self.local_file_path = local_dir.joinpath(remote_path) + self.local_sig_path = local_dir.joinpath(remote_path + sig_ext) + self.remote_client = remote_client + + log_message(LoggingScope.DEBUG, 'DEBUG', "Initialized EESSIDataAndSignatureObject for %s", remote_file_path) + log_message(LoggingScope.DEBUG, 'DEBUG', "Local file path: %s", self.local_file_path) + log_message(LoggingScope.DEBUG, 'DEBUG', "Local signature path: %s", self.local_sig_path) + + def _get_etag_file_path(self, local_path: Path) -> Path: + """Get the path to the .etag file for a given local file.""" + return local_path.with_suffix('.etag') + + def _get_local_etag(self, local_path: Path) -> Optional[str]: + """Get the ETag of a local file from its .etag file.""" + etag_path = self._get_etag_file_path(local_path) + if etag_path.exists(): + try: + with open(etag_path, 'r') as f: + return f.read().strip() + except Exception as err: + log_message(LoggingScope.DEBUG, 'WARNING', "Failed to read ETag file %s: %s", etag_path, str(err)) + return None + return None + + def get_etags(self) -> tuple[Optional[str], Optional[str]]: + """ + Get the ETags of both the data file and its signature. + + Returns: + Tuple containing (data_file_etag, signature_file_etag) + """ + return ( + self._get_local_etag(self.local_file_path), + self._get_local_etag(self.local_sig_path) + ) + + @log_function_entry_exit() + def verify_signature(self) -> bool: + """ + Verify the signature of the data file using the corresponding signature file. 
+ + Returns: + bool: True if the signature is valid or if signatures are not required, False otherwise + """ + # Check if signature file exists + if not self.local_sig_path.exists(): + log_message(LoggingScope.VERIFICATION, 'WARNING', "Signature file %s is missing", + self.local_sig_path) + + # If signatures are required, return failure + if self.config['signatures'].getboolean('signatures_required', True): + log_message(LoggingScope.ERROR, 'ERROR', "Signature file %s is missing and signatures are required", + self.local_sig_path) + return False + else: + log_message(LoggingScope.VERIFICATION, 'INFO', + "Signature file %s is missing, but signatures are not required", + self.local_sig_path) + return True + + # If signatures are provided, we should always verify them, regardless of the signatures_required setting + verify_runenv = self.config['signatures']['signature_verification_runenv'].split() + verify_script = self.config['signatures']['signature_verification_script'] + allowed_signers_file = self.config['signatures']['allowed_signers_file'] + + # Check if verification tools exist + if not Path(verify_script).exists(): + log_message(LoggingScope.ERROR, 'ERROR', + "Unable to verify signature: verification script %s does not exist", verify_script) + return False + + if not Path(allowed_signers_file).exists(): + log_message(LoggingScope.ERROR, 'ERROR', + "Unable to verify signature: allowed signers file %s does not exist", allowed_signers_file) + return False + + # Run the verification command with named parameters + cmd = verify_runenv + [ + verify_script, + '--verify', + '--allowed-signers-file', allowed_signers_file, + '--file', str(self.local_file_path), + '--signature-file', str(self.local_sig_path) + ] + log_message(LoggingScope.VERIFICATION, 'INFO', "Running command: %s", ' '.join(cmd)) + + try: + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode == 0: + log_message(LoggingScope.VERIFICATION, 'INFO', + "Successfully verified signature for %s", self.local_file_path) + log_message(LoggingScope.VERIFICATION, 'DEBUG', " stdout: %s", result.stdout) + log_message(LoggingScope.VERIFICATION, 'DEBUG', " stderr: %s", result.stderr) + return True + else: + log_message(LoggingScope.ERROR, 'ERROR', + "Signature verification failed for %s", self.local_file_path) + log_message(LoggingScope.ERROR, 'ERROR', " stdout: %s", result.stdout) + log_message(LoggingScope.ERROR, 'ERROR', " stderr: %s", result.stderr) + return False + except Exception as err: + log_message(LoggingScope.ERROR, 'ERROR', + "Error during signature verification for %s: %s", + self.local_file_path, str(err)) + return False + + @log_function_entry_exit() + def download(self, mode: DownloadMode = DownloadMode.CHECK_REMOTE) -> bool: + """ + Download data file and signature based on the specified mode. 
+ + Args: + mode: Download mode to use + + Returns: + True if files were downloaded, False otherwise + """ + # If mode is FORCE, we always download regardless of local or remote state + if mode == DownloadMode.FORCE: + should_download = True + log_message(LoggingScope.DOWNLOAD, 'INFO', "Forcing download of %s", self.remote_file_path) + # For CHECK_REMOTE mode, check if we can optimize + elif mode == DownloadMode.CHECK_REMOTE: + # Optimization: Check if local files exist first + local_files_exist = ( + self.local_file_path.exists() and + self.local_sig_path.exists() + ) + + # If files don't exist locally, we can skip ETag checks + if not local_files_exist: + log_message(LoggingScope.DOWNLOAD, 'INFO', + "Local files missing, skipping ETag checks and downloading %s", + self.remote_file_path) + should_download = True + else: + # First check if we have local ETags + try: + local_file_etag = self._get_local_etag(self.local_file_path) + local_sig_etag = self._get_local_etag(self.local_sig_path) + + if local_file_etag: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Local file ETag: %s", local_file_etag) + else: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "No local file ETag found") + if local_sig_etag: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Local signature ETag: %s", local_sig_etag) + else: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "No local signature ETag found") + + # If we don't have local ETags, we need to download + if not local_file_etag or not local_sig_etag: + should_download = True + log_message(LoggingScope.DOWNLOAD, 'INFO', "Missing local ETags, downloading %s", + self.remote_file_path) + else: + # Get remote ETags and compare + remote_file_etag = self.remote_client.get_metadata(self.remote_file_path)['ETag'] + remote_sig_etag = self.remote_client.get_metadata(self.remote_sig_path)['ETag'] + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Remote file ETag: %s", remote_file_etag) + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Remote signature ETag: %s", remote_sig_etag) + + should_download = ( + remote_file_etag != local_file_etag or + remote_sig_etag != local_sig_etag + ) + if should_download: + if remote_file_etag != local_file_etag: + log_message(LoggingScope.DOWNLOAD, 'INFO', "File ETag changed from %s to %s", + local_file_etag, remote_file_etag) + if remote_sig_etag != local_sig_etag: + log_message(LoggingScope.DOWNLOAD, 'INFO', "Signature ETag changed from %s to %s", + local_sig_etag, remote_sig_etag) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Remote files have changed, downloading %s", + self.remote_file_path) + else: + log_message(LoggingScope.DOWNLOAD, 'INFO', + "Remote files unchanged, skipping download of %s", + self.remote_file_path) + except Exception as etag_err: + # If we get any error with ETags, we'll just download the files + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error handling ETags, will download files: %s", + str(etag_err)) + should_download = True + else: # CHECK_LOCAL + should_download = ( + not self.local_file_path.exists() or + not self.local_sig_path.exists() + ) + if should_download: + if not self.local_file_path.exists(): + log_message(LoggingScope.DOWNLOAD, 'INFO', "Local file missing: %s", self.local_file_path) + if not self.local_sig_path.exists(): + log_message(LoggingScope.DOWNLOAD, 'INFO', "Local signature missing: %s", self.local_sig_path) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Local files missing, downloading %s", + self.remote_file_path) + else: + log_message(LoggingScope.DOWNLOAD, 'INFO', "Local files exist, skipping 
download of %s", + self.remote_file_path) + + if not should_download: + return False + + # Ensure local directory exists + self.local_file_path.parent.mkdir(parents=True, exist_ok=True) + + # Download files + try: + # Download the main file first + self.remote_client.download(self.remote_file_path, str(self.local_file_path)) + + # Get and log the ETag of the downloaded file + try: + file_etag = self._get_local_etag(self.local_file_path) + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Downloaded %s with ETag: %s", + self.remote_file_path, file_etag) + except Exception as etag_err: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error getting ETag for %s: %s", + self.remote_file_path, str(etag_err)) + + # Try to download the signature file + try: + self.remote_client.download(self.remote_sig_path, str(self.local_sig_path)) + try: + sig_etag = self._get_local_etag(self.local_sig_path) + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Downloaded %s with ETag: %s", + self.remote_sig_path, sig_etag) + except Exception as etag_err: + log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error getting ETag for %s: %s", + self.remote_sig_path, str(etag_err)) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s and its signature", + self.remote_file_path) + except Exception as sig_err: + # Check if signatures are required + if self.config['signatures'].getboolean('signatures_required', True): + # If signatures are required, clean up everything since we can't proceed + if self.local_file_path.exists(): + self.local_file_path.unlink() + # Clean up etag files regardless of whether their data files exist + file_etag_path = self._get_etag_file_path(self.local_file_path) + if file_etag_path.exists(): + file_etag_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.ERROR, 'ERROR', "Failed to download required signature for %s: %s", + self.remote_file_path, str(sig_err)) + raise + else: + # If signatures are optional, just clean up any partial signature files + if self.local_sig_path.exists(): + self.local_sig_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.DOWNLOAD, 'WARNING', "Failed to download optional signature for %s: %s", + self.remote_file_path, str(sig_err)) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s (signature optional)", + self.remote_file_path) + + return True + except Exception as err: + # This catch block is only for errors in the main file download + # Clean up partially downloaded files and their etags + if self.local_file_path.exists(): + self.local_file_path.unlink() + if self.local_sig_path.exists(): + self.local_sig_path.unlink() + # Clean up etag files regardless of whether their data files exist + file_etag_path = self._get_etag_file_path(self.local_file_path) + if file_etag_path.exists(): + file_etag_path.unlink() + sig_etag_path = self._get_etag_file_path(self.local_sig_path) + if sig_etag_path.exists(): + sig_etag_path.unlink() + log_message(LoggingScope.ERROR, 'ERROR', "Failed to download %s: %s", self.remote_file_path, str(err)) + raise + + @log_function_entry_exit() + def get_url(self) -> str: + """Get the URL of the data file.""" + return f"https://{self.remote_client.bucket}.s3.amazonaws.com/{self.remote_file_path}" + + def __str__(self) -> str: + """Return a string representation of the EESSI data and signature object.""" + 
return f"EESSIDataAndSignatureObject({self.remote_file_path})" diff --git a/scripts/automated_ingestion/eessi_task.py b/scripts/automated_ingestion/eessi_task.py new file mode 100644 index 00000000..bd863946 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task.py @@ -0,0 +1,1404 @@ +from enum import Enum, auto +from typing import Dict, List, Tuple, Optional +from functools import total_ordering + +import base64 +import os +import subprocess +import traceback + +from eessi_data_object import EESSIDataAndSignatureObject +from eessi_task_action import EESSITaskAction +from eessi_task_description import EESSITaskDescription +from eessi_task_payload import EESSITaskPayload +from utils import send_slack_message, log_message, LoggingScope, log_function_entry_exit + +from github import Github, GithubException, InputGitTreeElement, UnknownObjectException +from github.PullRequest import PullRequest +from github.Branch import Branch + + +class SequenceStatus(Enum): + DOES_NOT_EXIST = auto() + IN_PROGRESS = auto() + FINISHED = auto() + + +@total_ordering +class TaskState(Enum): + UNDETERMINED = auto() # The task state was not determined yet + NEW_TASK = auto() # The task has been created but not yet processed + PAYLOAD_STAGED = auto() # The task's payload has been staged to the Stratum-0 + PULL_REQUEST = auto() # A PR for the task has been created or updated in some staging repository + APPROVED = auto() # The PR for the task has been approved + REJECTED = auto() # The PR for the task has been rejected + INGESTED = auto() # The task's payload has been applied to the target CernVM-FS repository + DONE = auto() # The task has been completed + + @classmethod + def from_string(cls, name, default=None, case_sensitive=False): + log_message(LoggingScope.TASK_OPS, 'INFO', "from_string: %s", name) + if case_sensitive: + to_return = cls.__members__.get(name, default) + log_message(LoggingScope.TASK_OPS, 'INFO', "from_string will return: %s", to_return) + return to_return + + try: + to_return = cls[name.upper()] + log_message(LoggingScope.TASK_OPS, 'INFO', "from_string will return: %s", to_return) + return to_return + except KeyError: + return default + + def __lt__(self, other): + if self.__class__ is other.__class__: + return self.value < other.value + return NotImplemented + + def __str__(self): + return self.name.upper() + + +class EESSITask: + description: EESSITaskDescription + payload: EESSITaskPayload + action: EESSITaskAction + git_repo: Github + config: Dict + + @log_function_entry_exit() + def __init__(self, description: EESSITaskDescription, config: Dict, cvmfs_repo: str, git_repo: Github): + self.description = description + self.config = config + self.cvmfs_repo = cvmfs_repo + self.git_repo = git_repo + self.action = self._determine_task_action() + + # Define valid state transitions for all actions + # NOTE, TaskState.APPROVED must be the first element or _next_state() will not work + self.valid_transitions = { + TaskState.UNDETERMINED: [TaskState.NEW_TASK, TaskState.PAYLOAD_STAGED, TaskState.PULL_REQUEST, + TaskState.APPROVED, TaskState.REJECTED, TaskState.INGESTED, TaskState.DONE], + TaskState.NEW_TASK: [TaskState.PAYLOAD_STAGED], + TaskState.PAYLOAD_STAGED: [TaskState.PULL_REQUEST], + TaskState.PULL_REQUEST: [TaskState.APPROVED, TaskState.REJECTED], + TaskState.APPROVED: [TaskState.INGESTED], + TaskState.REJECTED: [], # terminal state + TaskState.INGESTED: [], # terminal state + TaskState.DONE: [] # virtual terminal state, not used to write on GitHub + } + + self.payload = None + state = 
self.determine_state() + if state >= TaskState.PAYLOAD_STAGED: + log_message(LoggingScope.TASK_OPS, 'INFO', "initializing payload object in constructor for EESSITask") + self._init_payload_object() + + @log_function_entry_exit() + def _determine_task_action(self) -> EESSITaskAction: + """ + Determine the action type based on task description metadata. + """ + if 'task' in self.description.metadata and 'action' in self.description.metadata['task']: + action_str = self.description.metadata['task']['action'].lower() + if action_str == "nop": + return EESSITaskAction.NOP + elif action_str == "delete": + return EESSITaskAction.DELETE + elif action_str == "add": + return EESSITaskAction.ADD + elif action_str == "update": + return EESSITaskAction.UPDATE + return EESSITaskAction.UNKNOWN + + @log_function_entry_exit() + def _state_file_with_prefix_exists_in_repo_branch(self, file_path_prefix: str, branch_name: str = None) -> bool: + """ + Check if a file exists in a repository branch. + + Args: + file_path_prefix: the prefix of the file path + branch_name: the branch to check + + Returns: + True if a file with the prefix exists in the branch, False otherwise + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + # branch = self._get_branch_from_name(branch_name) + try: + # get all files in directory part of file_path_prefix + directory_part = os.path.dirname(file_path_prefix) + files = self.git_repo.get_contents(directory_part, ref=branch_name) + log_msg = "Found files %s in directory %s in branch %s" + log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, files, directory_part, branch_name) + # check if any of the files has file_path_prefix as prefix + for file in files: + if file.path.startswith(file_path_prefix): + log_msg = "Found file %s in directory %s in branch %s" + log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, file.path, directory_part, branch_name) + return True + log_msg = "No file with prefix %s found in directory %s in branch %s" + log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, file_path_prefix, directory_part, branch_name) + return False + except UnknownObjectException: + # file_path does not exist in branch + log_msg = "Directory %s or file with prefix %s does not exist in branch %s" + log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, directory_part, file_path_prefix, branch_name) + return False + except GithubException as err: + if err.status == 404: + # file_path does not exist in branch + log_msg = "Directory %s or file with prefix %s does not exist in branch %s" + log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, directory_part, file_path_prefix, branch_name) + return False + else: + # if there was some other (e.g. connection) issue, log message and return False + log_msg = 'Unable to determine the state of %s, the GitHub API returned status %s!' + log_message(LoggingScope.ERROR, 'WARNING', log_msg, self.object, err.status) + return False + return False + + @log_function_entry_exit() + def _determine_sequence_numbers_including_task_file(self, repo: str, pr: str) -> Dict[int, bool]: + """ + Determines in which sequence numbers the metadata/task file is included and in which it is not. + NOTE, we only need to check the default branch of the repository, because a for a new task a file + is added to the default branch and for the subsequent processing of the task we use a different branch. + Thus, until the PR is closed, the task file stays in the default branch. 
+ + Args: + repo: the repository name + pr: the pull request number + + Returns: + A dictionary with the sequence numbers as keys and a boolean value indicating if the metadata/task file is + included in that sequence number. + + Idea: + - The deployment for a single source PR could be split into multiple staging PRs each is assigned a unique + sequence number. + - For a given source PR (identified by the repo name and the PR number), a staging PR using a branch named + `REPO/PR_NUM/SEQ_NUM` is created. + - In the staging repo we create a corresponding directory `REPO/PR_NUM/SEQ_NUM`. + - If a metadata/task file is handled by the staging PR with sequence number, it is included in that directory. + - We iterate over all directories under `REPO/PR_NUM`: + - If the metadata/task file is available in the directory, we add the sequence number to the list. + + Note: this is a placeholder for now, as we do not know yet if we need to use a sequence number. + """ + sequence_numbers = {} + repo_pr_dir = f"{repo}/{pr}" + # iterate over all directories under repo_pr_dir + try: + directories = self._list_directory_contents(repo_pr_dir) + for dir in directories: + # check if the directory is a number + if dir.name.isdigit(): + # determine if a state file with prefix exists in the sequence number directory + # we need to use the basename of the remote file path + remote_file_path_basename = os.path.basename(self.description.task_object.remote_file_path) + state_file_name_prefix = f"{repo_pr_dir}/{dir.name}/{remote_file_path_basename}" + if self._state_file_with_prefix_exists_in_repo_branch(state_file_name_prefix): + sequence_numbers[int(dir.name)] = True + else: + sequence_numbers[int(dir.name)] = False + else: + # directory is not a number, so we skip it + continue + except FileNotFoundError: + # repo_pr_dir does not exist, so we return an empty dictionary + return {} + except GithubException as err: + if err.status != 404: # 404 is catched by FileNotFoundError + # some other error than the directory not existing + return {} + return sequence_numbers + + @log_function_entry_exit() + def _find_highest_number(self, str_list: List[str]) -> int: + """ + Find the highest number in a list of strings. + """ + # Convert all strings to integers + int_list = [int(num) for num in str_list] + return max(int_list) + + @log_function_entry_exit() + def _get_sequence_number_for_task_file(self) -> int: + """ + Get the sequence number this task is assigned to at the moment. + NOTE, should only be called if the task is actually assigned to a sequence number. + """ + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number) + if len(sequence_numbers) == 0: + raise ValueError("Found no sequence numbers at all") + else: + # get all entries with value True, there should be only one, so we return the first one + sequence_numbers_true = [key for key, value in sequence_numbers.items() if value is True] + if len(sequence_numbers_true) == 0: + raise ValueError("Found no sequence numbers that include the task file for task %s", + self.description) + else: + return sequence_numbers_true[0] + + @log_function_entry_exit() + def _get_current_sequence_number(self, sequence_numbers: Dict[int, bool] = None) -> int: + """ + Get the current sequence number based on the sequence numbers. + If sequence_numbers is not provided, we determine the sequence numbers from the task description. 
+ """ + if sequence_numbers is None: + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number) + if len(sequence_numbers) == 0: + return 0 + return self._find_highest_number(sequence_numbers.keys()) + + @log_function_entry_exit() + def _get_fixed_sequence_number(self) -> int: + """ + Get a fixed sequence number. + """ + return 11 + + @log_function_entry_exit() + def _determine_sequence_status(self, sequence_number: int = None) -> int: + """ + Determine the status of the sequence number. It could be: DOES_NOT_EXIST, IN_PROGRESS, FINISHED + If sequence_number is not provided, we use the highest existing sequence number. + """ + if sequence_number is None: + sequence_number = self._get_current_sequence_number() + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number) + if len(sequence_numbers) == 0: + return SequenceStatus.DOES_NOT_EXIST + elif sequence_number not in sequence_numbers.keys(): + return SequenceStatus.DOES_NOT_EXIST + elif sequence_number < self._find_highest_number(sequence_numbers.keys()): + return SequenceStatus.FINISHED + else: + # check status of PR if it exists + branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + if branch_name in [branch.name for branch in self.git_repo.get_branches()]: + find_pr = [pr for pr in self.git_repo.get_pulls(head=branch_name, state='all')] + if find_pr: + pr = find_pr.pop(0) + if pr.state == 'closed': + return SequenceStatus.FINISHED + return SequenceStatus.IN_PROGRESS + + @log_function_entry_exit() + def _find_staging_pr(self) -> Tuple[Optional[PullRequest], Optional[str], Optional[int]]: + """ + Find the staging PR for the task. + TODO: arg sequence number --> make function simpler + """ + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + try: + sequence_number = self._get_sequence_number_for_task_file() + except ValueError: + # no sequence number found, so we return None + log_message(LoggingScope.ERROR, 'ERROR', "no sequence number found for task %s", self.description) + return None, None, None + except Exception as err: + # some other error + log_message(LoggingScope.ERROR, 'ERROR', "error finding staging PR for task %s: %s", + self.description, err) + return None, None, None + branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + if branch_name in [branch.name for branch in self.git_repo.get_branches()]: + find_pr = [pr for pr in self.git_repo.get_pulls(head=branch_name, state='all')] + if find_pr: + pr = find_pr.pop(0) + return pr, branch_name, sequence_number + else: + return None, branch_name, sequence_number + else: + return None, None, None + + @log_function_entry_exit() + def _create_staging_pr(self, sequence_number: int) -> Tuple[PullRequest, str]: + """ + Create a staging PR for the task. + NOTE, SHALL only be called if no staging PR for the task exists yet. 
+ """ + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + default_branch_name = self.git_repo.default_branch + pr = self.git_repo.create_pull(title=f"Add task for {repo_name} PR {pr_number} seq {sequence_number}", + body=f"Add task for {repo_name} PR {pr_number} seq {sequence_number}", + head=branch_name, base=default_branch_name) + return pr, branch_name + + @log_function_entry_exit() + def _find_state(self) -> TaskState: + """ + Determine the state of the task based on the task description metadata. + + Returns: + The state of the task. + """ + # obtain repo and pr from metadata + log_message(LoggingScope.TASK_OPS, 'INFO', "finding state of task %s", self.description.task_object) + repo = self.description.get_repo_name() + pr = self.description.get_pr_number() + log_message(LoggingScope.TASK_OPS, 'INFO', "repo: %s, pr: %s", repo, pr) + + # obtain all sequence numbers in repo/pr dir which include a state file for this task + sequence_numbers = self._determine_sequence_numbers_including_task_file(repo, pr) + if len(sequence_numbers) == 0: + # no sequence numbers found, so we return NEW_TASK + log_message(LoggingScope.TASK_OPS, 'INFO', "no sequence numbers found, state: NEW_TASK") + return TaskState.NEW_TASK + # we got at least one sequence number + # if one value for a sequence number is True, we can determine the state from the file in the directory + sequence_including_task = [key for key, value in sequence_numbers.items() if value is True] + if len(sequence_including_task) == 0: + # no sequence number includes the task file, so we return NEW_TASK + log_message(LoggingScope.TASK_OPS, 'INFO', "no sequence number includes the task file, state: NEW_TASK") + return TaskState.NEW_TASK + # we got at least one sequence number which includes the task file + # we can determine the state from the filename in the directory + # NOTE, we use the first element in sequence_including_task (there should be only one) + # we ignore other elements in sequence_including_task + sequence_number = sequence_including_task[0] + task_file_name = self.description.get_task_file_name() + metadata_file_state_path_prefix = f"{repo}/{pr}/{sequence_number}/{task_file_name}." + state = self._get_state_for_metadata_file_prefix(metadata_file_state_path_prefix, sequence_number) + log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state) + return state + + @log_function_entry_exit() + def _get_state_for_metadata_file_prefix(self, metadata_file_state_path_prefix: str, + sequence_number: int) -> TaskState: + """ + Get the state from the file in the metadata_file_state_path_prefix. 
+ """ + # depending on the state of the deployment (NEW_TASK, PAYLOAD_STAGED, PULL_REQUEST, APPROVED, REJECTED, + # INGESTED, DONE) + # we need to check the task file in the default branch or in the branch corresponding to the sequence number + directory_part = os.path.dirname(metadata_file_state_path_prefix) + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + default_branch_name = self.git_repo.default_branch + branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + all_branch_names = [branch.name for branch in self.git_repo.get_branches()] + states = [] + for branch in [default_branch_name, branch_name]: + if branch in all_branch_names: + # first get all files in directory part of metadata_file_state_path_prefix + files = self._list_directory_contents(directory_part, branch) + # check if any of the files has metadata_file_state_path_prefix as prefix + for file in files: + if file.path.startswith(metadata_file_state_path_prefix): + # get state from file name taking only the suffix + state = TaskState.from_string(file.name.split('.')[-1]) + log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state) + states.append(state) + if len(states) == 0: + # did not find any file with metadata_file_state_path_prefix as prefix + log_message(LoggingScope.TASK_OPS, 'INFO', "did not find any file with prefix %s", + metadata_file_state_path_prefix) + return TaskState.NEW_TASK + # sort the states and return the last one + states.sort() + state = states[-1] + log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state) + return state + + @log_function_entry_exit() + def _list_directory_contents(self, directory_path, branch_name: str = None): + try: + # Get contents of the directory + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + log_message(LoggingScope.TASK_OPS, 'INFO', + "listing contents of %s in branch %s", directory_path, branch_name) + contents = self.git_repo.get_contents(directory_path, ref=branch_name) + + # If contents is a list, it means we successfully got directory contents + if isinstance(contents, list): + return contents + else: + # If it's not a list, it means the path is not a directory + raise ValueError(f"{directory_path} is not a directory") + except GithubException as err: + if err.status == 404: + raise FileNotFoundError(f"Directory not found: {directory_path}") + raise err + + @log_function_entry_exit() + def _next_state(self, state: TaskState = None) -> TaskState: + """ + Determine the next state based on the current state using the valid_transitions dictionary. + + NOTE, it assumes that function is only called for non-terminal states and that the next state is the first + element of the list returned by the valid_transitions dictionary. + """ + the_state = state if state is not None else self.determine_state() + return self.valid_transitions[the_state][0] + + @log_function_entry_exit() + def _path_exists_in_branch(self, path: str, branch_name: str = None) -> bool: + """ + Check if a path exists in a branch. + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + try: + self.git_repo.get_contents(path, ref=branch_name) + return True + except GithubException as err: + if err.status == 404: + return False + else: + raise err + + @log_function_entry_exit() + def _read_dict_from_string(self, content: str) -> dict: + """ + Read the dictionary from the string. 
+ """ + config_dict = {} + for line in content.strip().split('\n'): + if '=' in line and not line.strip().startswith('#'): # Skip comments + key, value = line.split('=', 1) # Split only on first '=' + config_dict[key.strip()] = value.strip() + return config_dict + + @log_function_entry_exit() + def _read_pull_request_dir_from_file(self, task_pointer_file: str = None, branch_name: str = None) -> str: + """ + Read the pull request directory from the file in the given branch. + """ + # set default values for task pointer file and branch name + if task_pointer_file is None: + task_pointer_file = self.description.task_object.remote_file_path + if branch_name is None: + branch_name = self.git_repo.default_branch + log_message(LoggingScope.TASK_OPS, 'INFO', "reading pull request directory from file '%s' in branch '%s'", + task_pointer_file, branch_name) + + # read the pull request directory from the file in the given branch + content = self.git_repo.get_contents(task_pointer_file, ref=branch_name) + + # Decode the content from base64 + content_str = content.decoded_content.decode('utf-8') + + # Parse into dictionary + config_dict = self._read_dict_from_string(content_str) + + target_dir = config_dict.get('target_dir', None) + return config_dict.get('pull_request_dir', target_dir) + + @log_function_entry_exit() + def _determine_pull_request_dir(self, task_pointer_file: str = None, branch_name: str = None) -> str: + """Determine the pull request directory via the task pointer file""" + return self._read_pull_request_dir_from_file(task_pointer_file=task_pointer_file, branch_name=branch_name) + + @log_function_entry_exit() + def _get_branch_from_name(self, branch_name: str = None) -> Optional[Branch]: + """ + Get a branch object from its name. + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + + try: + branch = self.git_repo.get_branch(branch_name) + log_message(LoggingScope.TASK_OPS, 'INFO', "branch %s exists: %s", branch_name, branch) + return branch + except Exception as err: + log_message(LoggingScope.TASK_OPS, 'ERROR', "error checking if branch %s exists: %s", + branch_name, err) + return None + + @log_function_entry_exit() + def _read_task_state_from_file(self, path: str, branch_name: str = None) -> TaskState: + """ + Read the task state from the file in the given branch. + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + content = self.git_repo.get_contents(path, ref=branch_name) + + # Decode the content from base64 + content_str = content.decoded_content.decode('utf-8').strip() + log_message(LoggingScope.TASK_OPS, 'INFO', "content in TaskState file: %s", content_str) + + task_state = TaskState.from_string(content_str) + log_message(LoggingScope.TASK_OPS, 'INFO', "task state: %s", task_state) + + return task_state + + @log_function_entry_exit() + def determine_state(self, branch: str = None) -> TaskState: + """ + Determine the state of the task based on the state of the staging repository. 
+ """ + # check if path representing the task file exists in the default branch or the "feature" branch + task_pointer_file = self.description.task_object.remote_file_path + branch_to_use = self.git_repo.default_branch if branch is None else branch + + if self._path_exists_in_branch(task_pointer_file, branch_name=branch_to_use): + log_message(LoggingScope.TASK_OPS, 'INFO', "path '%s' exists in branch '%s'", + task_pointer_file, branch_to_use) + + # get state from task file in branch to use + # - read the TaskState file in pull request directory + pull_request_dir = self._determine_pull_request_dir(branch_name=branch_to_use) + log_message(LoggingScope.TASK_OPS, 'INFO', "pull request directory: '%s'", pull_request_dir) + task_state_file_path = f"{pull_request_dir}/TaskState" + log_message(LoggingScope.TASK_OPS, 'INFO', "task state file path: '%s'", task_state_file_path) + task_state = self._read_task_state_from_file(task_state_file_path, branch_to_use) + + log_message(LoggingScope.TASK_OPS, 'INFO', "task state in branch '%s': %s", + branch_to_use, task_state) + return task_state + else: + log_message(LoggingScope.TASK_OPS, 'INFO', "path '%s' does not exist in branch '%s'", + task_pointer_file, branch_to_use) + return TaskState.UNDETERMINED + + @log_function_entry_exit() + def handle(self): + """ + Dynamically find and execute the appropriate handler based on action and state. + """ + state_before_handle = self.determine_state() + + # Construct handler method name + handler_name = f"_handle_{self.action}_{str(state_before_handle).lower()}" + + # Check if the handler exists + handler = getattr(self, handler_name, None) + + if handler and callable(handler): + # Execute the handler if it exists + return handler() + else: + # Default behavior for missing handlers + log_message(LoggingScope.TASK_OPS, 'ERROR', + "No handler for action %s and state %s implemented; nothing to be done", + self.action, state_before_handle) + return state_before_handle + + # Implement handlers for ADD action + @log_function_entry_exit() + def _safe_create_file(self, path: str, message: str, content: str, branch_name: str = None): + """Create a file in the given branch.""" + try: + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + existing_file = self.git_repo.get_contents(path, ref=branch_name) + log_message(LoggingScope.TASK_OPS, 'INFO', "File %s already exists", path) + return existing_file + except GithubException as err: + if err.status == 404: # File doesn't exist + # Safe to create + return self.git_repo.create_file(path, message, content, branch=branch_name) + else: + raise err # Some other error + + @log_function_entry_exit() + def _create_multi_file_commit(self, files_data, commit_message, branch_name: str = None): + """ + Create a commit with multiple file changes + + files_data: dict with structure: + { + "path/to/file1.txt": { + "content": "file content", + "mode": "100644" # optional, defaults to 100644 + }, + "path/to/file2.py": { + "content": "print('hello')", + "mode": "100644" + } + } + """ + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + ref = self.git_repo.get_git_ref(f"heads/{branch_name}") + current_commit = self.git_repo.get_git_commit(ref.object.sha) + base_tree = current_commit.tree + + # Create tree elements + tree_elements = [] + for file_path, file_info in files_data.items(): + content = file_info["content"] + if isinstance(content, str): + content = content.encode('utf-8') + + blob = self.git_repo.create_git_blob( + 
base64.b64encode(content).decode('utf-8'),
+                "base64"
+            )
+            tree_elements.append(InputGitTreeElement(
+                path=file_path,
+                mode=file_info.get("mode", "100644"),
+                type="blob",
+                sha=blob.sha
+            ))
+
+        # Create new tree
+        new_tree = self.git_repo.create_git_tree(tree_elements, base_tree)
+
+        # Create commit
+        new_commit = self.git_repo.create_git_commit(
+            commit_message,
+            new_tree,
+            [current_commit]
+        )
+
+        # Update branch reference
+        ref.edit(new_commit.sha)
+
+        return new_commit
+
+    @log_function_entry_exit()
+    def _update_file(self, file_path, new_content, commit_message, branch_name: str = None) -> Optional[Dict]:
+        try:
+            branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+            # Get the current file
+            file = self.git_repo.get_contents(file_path, ref=branch_name)
+
+            # Update the file
+            result = self.git_repo.update_file(
+                path=file_path,
+                message=commit_message,
+                content=new_content,
+                sha=file.sha,
+                branch=branch_name
+            )
+
+            log_message(LoggingScope.TASK_OPS, 'INFO',
+                        "File updated successfully. Commit SHA: %s", result['commit'].sha)
+            return result
+
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, 'ERROR', "Error updating file: %s", err)
+            return None
+
+    @log_function_entry_exit()
+    def _sorted_list_of_sequence_numbers(self) -> List[int]:
+        """Create a sorted list of sequence numbers from the pull requests directory"""
+        # a pull request's directory is of the form REPO/PR/SEQ
+        # hence, we can get all sequence numbers from the pull requests directory REPO/PR
+        sequence_numbers = []
+        repo_pr_dir = f"{self.description.get_repo_name()}/{self.description.get_pr_number()}"
+
+        # iterate over all directories under repo_pr_dir
+        try:
+            directories = self._list_directory_contents(repo_pr_dir)
+            for dir in directories:
+                # check if the directory is a number
+                if dir.name.isdigit():
+                    sequence_numbers.append(int(dir.name))
+                else:
+                    # directory is not a number, so we skip it
+                    continue
+        except FileNotFoundError:
+            # repo_pr_dir does not exist, so we return an empty list
+            log_message(LoggingScope.TASK_OPS, 'ERROR', "Pull requests directory '%s' does not exist", repo_pr_dir)
+        except GithubException as err:
+            if err.status != 404:  # 404 is caught by FileNotFoundError
+                # some other error than the directory not existing
+                log_message(LoggingScope.TASK_OPS, 'ERROR',
+                            "Some other error than the directory not existing: %s", err)
+        except Exception as err:
+            log_message(LoggingScope.TASK_OPS, 'ERROR', "Unexpected error: %s", err)
+
+        return sorted(sequence_numbers)
+
+    @log_function_entry_exit()
+    def _determine_sequence_number(self) -> int:
+        """Determine the sequence number for the task"""
+
+        sequence_numbers = self._sorted_list_of_sequence_numbers()
+        log_message(LoggingScope.TASK_OPS, 'INFO', "number of sequence numbers: %d", len(sequence_numbers))
+        if len(sequence_numbers) == 0:
+            return 0
+
+        log_message(LoggingScope.TASK_OPS, 'INFO', "sequence numbers: [%s]", ", ".join(map(str, sequence_numbers)))
+
+        # get the highest sequence number
+        highest_sequence_number = sequence_numbers[-1]
+        log_message(LoggingScope.TASK_OPS, 'INFO', "highest sequence number: %d", highest_sequence_number)
+
+        pull_request = self._find_pr_for_sequence_number(highest_sequence_number)
+        log_message(LoggingScope.TASK_OPS, 'INFO', "pull request: %s", pull_request)
+
+        if pull_request is None:
+            log_message(LoggingScope.TASK_OPS, 'INFO', "Did not find pull request for sequence number %d",
+                        highest_sequence_number)
+            # the directory for the sequence number 
exists but no PR yet + return highest_sequence_number + else: + log_message(LoggingScope.TASK_OPS, 'INFO', "pull request found: %s", pull_request) + log_message(LoggingScope.TASK_OPS, 'INFO', "pull request state/merged: %s/%s", + pull_request.state, str(pull_request.is_merged())) + if pull_request.is_merged(): + # the PR is merged, so we use the next sequence number + return highest_sequence_number + 1 + else: + # the PR is not merged, so we can use the current sequence number + return highest_sequence_number + + @log_function_entry_exit() + def _handle_add_undetermined(self): + """Handler for ADD action in UNDETERMINED state""" + print("Handling ADD action in UNDETERMINED state: %s" % self.description.get_task_file_name()) + # task is in state UNDETERMINED if there is no pull request directory for the task yet + # + # create pull request directory (REPO/PR/SEQ/TASK_FILE_NAME/) + # create task file in pull request directory (PULL_REQUEST_DIR/TaskDescription) + # create task status file in pull request directory (PULL_REQUEST_DIR/TaskState.NEW_TASK) + # create pointer file from task file path to pull request directory (remote_file_path -> PULL_REQUEST_DIR) + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + sequence_number = self._determine_sequence_number() # corresponds to an open or yet to be created PR + task_file_name = self.description.get_task_file_name() + # we cannot use self._determine_pull_request_dir() here because it requires a task pointer file + # and we don't have one yet + pull_request_dir = f"{repo_name}/{pr_number}/{sequence_number}/{task_file_name}" + task_description_file_path = f"{pull_request_dir}/TaskDescription" + task_state_file_path = f"{pull_request_dir}/TaskState" + remote_file_path = self.description.task_object.remote_file_path + + files_to_commit = { + task_description_file_path: { + "content": self.description.get_contents(), + "mode": "100644" + }, + task_state_file_path: { + "content": f"{TaskState.NEW_TASK.name}\n", + "mode": "100644" + }, + remote_file_path: { + "content": f"remote_file_path = {remote_file_path}\npull_request_dir = {pull_request_dir}", + "mode": "100644" + } + } + + branch_name = self.git_repo.default_branch + try: + commit = self._create_multi_file_commit( + files_to_commit, + f"new task for {repo_name} PR {pr_number} seq {sequence_number}", + branch_name=branch_name + ) + log_message(LoggingScope.TASK_OPS, 'INFO', "commit created: %s", commit) + except Exception as err: + log_message(LoggingScope.TASK_OPS, 'ERROR', "Error creating commit: %s", err) + # TODO: rollback previous changes (task description file, task state file) + return TaskState.UNDETERMINED + + # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number + # is still open or yet to be created); if it is not valid, perform corrective actions + return TaskState.NEW_TASK + + @log_function_entry_exit() + def _update_task_state_file(self, next_state: TaskState, branch_name: str = None) -> Optional[Dict]: + """Update the TaskState file content in default or given branch""" + branch_name = self.git_repo.default_branch if branch_name is None else branch_name + + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, branch_name) + task_state_file_path = f"{pull_request_dir}/TaskState" + arch = self.description.get_metadata_file_components()[3] + commit_message = f"change task state to {next_state} in {branch_name} 
for {arch}"
+        result = self._update_file(task_state_file_path,
+                                   f"{next_state.name}\n",
+                                   commit_message,
+                                   branch_name=branch_name)
+        return result
+
+    @log_function_entry_exit()
+    def _init_payload_object(self):
+        """Initialize the payload object"""
+        if self.payload is not None:
+            log_message(LoggingScope.TASK_OPS, 'INFO', "payload object already initialized")
+            return
+
+        # get name of payload from metadata
+        payload_name = self.description.metadata['payload']['filename']
+        log_message(LoggingScope.TASK_OPS, 'INFO', "payload_name: %s", payload_name)
+
+        # get config and remote_client from self.description.task_object
+        config = self.description.task_object.config
+        remote_client = self.description.task_object.remote_client
+
+        # determine remote_file_path by replacing basename of remote_file_path in self.description.task_object
+        # with payload_name
+        description_remote_file_path = self.description.task_object.remote_file_path
+        payload_remote_file_path = os.path.join(os.path.dirname(description_remote_file_path), payload_name)
+        log_message(LoggingScope.TASK_OPS, 'INFO', "payload_remote_file_path: %s", payload_remote_file_path)
+
+        # initialize payload object
+        payload_object = EESSIDataAndSignatureObject(config, payload_remote_file_path, remote_client)
+        self.payload = EESSITaskPayload(payload_object)
+        log_message(LoggingScope.TASK_OPS, 'INFO', "payload: %s", self.payload)
+
+    @log_function_entry_exit()
+    def _handle_add_new_task(self):
+        """Handler for ADD action in NEW_TASK state"""
+        print("Handling ADD action in NEW_TASK state: %s" % self.description.get_task_file_name())
+        # determine next state
+        next_state = self._next_state(TaskState.NEW_TASK)
+        log_message(LoggingScope.TASK_OPS, 'INFO', "next_state: %s", next_state)
+
+        # initialize payload object
+        self._init_payload_object()
+
+        # update TaskState file content
+        self._update_task_state_file(next_state)
+
+        # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number
+        #       is still open or yet to be created); if it is not valid, perform corrective actions
+        return next_state
+
+    @log_function_entry_exit()
+    def _find_pr_for_branch(self, branch_name: str) -> Optional[PullRequest]:
+        """
+        Find the single PR for the given branch in any state. 
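+
+        Feature branches created by this script typically follow the pattern
+        '<org>-<repo>-PR-<pr>-SEQ-<seq>' (see _determine_feature_branch_name).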
+ + Args: + repo: GitHub repository + branch_name: Name of the branch + + Returns: + PullRequest object if found, None otherwise + """ + try: + prs = [pr for pr in list(self.git_repo.get_pulls(state='all')) + if pr.head.ref == branch_name] + log_message(LoggingScope.TASK_OPS, 'INFO', "number of PRs found: %d", len(prs)) + if len(prs): + log_message(LoggingScope.TASK_OPS, 'INFO', "1st PR found: %d, %s", prs[0].number, prs[0].head.ref) + return prs[0] if prs else None + except Exception as err: + log_message(LoggingScope.TASK_OPS, 'ERROR', "Error finding PR for branch %s: %s", branch_name, err) + return None + + @log_function_entry_exit() + def _find_pr_for_sequence_number(self, sequence_number: int) -> Optional[PullRequest]: + """Find the PR for the given sequence number""" + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + feature_branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}" + + # list all PRs with head_ref starting with the feature branch name without the sequence number + last_dash = feature_branch_name.rfind('-') + if last_dash != -1: + head_ref_wout_seq_num = feature_branch_name[:last_dash + 1] # +1 to include the separator + else: + head_ref_wout_seq_num = feature_branch_name + + log_message(LoggingScope.TASK_OPS, 'INFO', + "searching for PRs whose head_ref starts with: '%s'", head_ref_wout_seq_num) + + all_prs = [pr for pr in list(self.git_repo.get_pulls(state='all')) + if pr.head.ref.startswith(head_ref_wout_seq_num)] + log_message(LoggingScope.TASK_OPS, 'INFO', " number of PRs found: %d", len(all_prs)) + for pr in all_prs: + log_message(LoggingScope.TASK_OPS, 'INFO', " PR #%d: %s", pr.number, pr.head.ref) + + # now, find the PR for the feature branch name (if any) + log_message(LoggingScope.TASK_OPS, 'INFO', + "searching PR for feature branch name: '%s'", feature_branch_name) + pull_request = self._find_pr_for_branch(feature_branch_name) + log_message(LoggingScope.TASK_OPS, 'INFO', "pull request for branch '%s': %s", + feature_branch_name, pull_request) + return pull_request + + @log_function_entry_exit() + def _determine_sequence_number_from_pull_request_directory(self) -> int: + """Determine the sequence number from the pull request directory name""" + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch) + # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo) + _, _, _, seq, _ = pull_request_dir.split('/') + return int(seq) + + @log_function_entry_exit() + def _determine_feature_branch_name(self) -> str: + """Determine the feature branch name from the pull request directory name""" + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch) + # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo) + org, repo, pr, seq, _ = pull_request_dir.split('/') + return f"{org}-{repo}-PR-{pr}-SEQ-{seq}" + + @log_function_entry_exit() + def _sync_task_state_file(self, source_branch: str, target_branch: str): + """Update task state file from source to target branch""" + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch) + task_state_file_path = 
f"{pull_request_dir}/TaskState" + + try: + # Get content from source branch + source_content = self.git_repo.get_contents(task_state_file_path, ref=source_branch) + + # Get current file in target branch + target_file = self.git_repo.get_contents(task_state_file_path, ref=target_branch) + + # Update if content is different + if source_content.sha != target_file.sha: + result = self.git_repo.update_file( + path=task_state_file_path, + message=f"Sync {task_state_file_path} from {source_branch} to {target_branch}", + content=source_content.decoded_content, + sha=target_file.sha, + branch=target_branch + ) + log_message(LoggingScope.TASK_OPS, 'INFO', "Updated %s", task_state_file_path) + return result + else: + log_message(LoggingScope.TASK_OPS, 'INFO', "No changes needed for %s", task_state_file_path) + return None + + except Exception as err: + log_message(LoggingScope.TASK_OPS, 'ERROR', "Error syncing task state file: %s", err) + return None + + @log_function_entry_exit() + def _update_task_states(self, next_state: TaskState, default_branch_name: str, + approved_state: TaskState, feature_branch_name: str): + """ + Update task states in default and feature branches + + States have to be updated in a specific order and in particular the default branch has to be + merged into the feature branch before the feature branch can be updated to avoid a merge conflict. + + Args: + next_state: next state to be applied to the default branch + default_branch_name: name of the default branch + approved_state: state to be applied to the feature branch + feature_branch_name: name of the feature branch + """ + # TODO: add failure handling (capture failures and return them somehow) + + # update TaskState file content + # - next_state in default branch (interpreted as current state) + # - approved_state in feature branch (interpreted as future state, ie, after + # the PR corresponding to the feature branch will be merged) + + # first, update the task state file in the default branch + self._update_task_state_file(next_state, branch_name=default_branch_name) + + # second, merge default branch into feature branch (to avoid a merge conflict) + # TODO: store arch info (CPU+ACCEL) in task/metdata file and then access that rather + # than using a part of the file name + arch = self.description.get_metadata_file_components()[3] + commit_message = f"merge {default_branch_name} into {feature_branch_name} for {arch}" + self.git_repo.merge( + head=default_branch_name, + base=feature_branch_name, + commit_message=commit_message + ) + + # last, update task state file in feature branch + self._update_task_state_file(approved_state, branch_name=feature_branch_name) + log_message(LoggingScope.TASK_OPS, 'INFO', + "TaskState file updated to %s in default branch (%s) and to %s in feature branch (%s)", + next_state, default_branch_name, approved_state, feature_branch_name) + + @log_function_entry_exit() + def _create_task_summary(self) -> str: + """Analyse contents of current task and create a file for it in the REPO-PR-SEQ directory.""" + + # determine task summary file path in feature branch on GitHub + feature_branch_name = self._determine_feature_branch_name() + pull_request_dir = self._determine_pull_request_dir(branch_name=feature_branch_name) + task_summary_file_path = f"{pull_request_dir}/TaskSummary.html" + + # check if task summary file already exists in repo on GitHub + if self._path_exists_in_branch(task_summary_file_path, feature_branch_name): + log_message(LoggingScope.TASK_OPS, 'INFO', "task summary file already 
exists: %s", task_summary_file_path) + task_summary = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name) + # return task_summary.decoded_content + return task_summary + + # create task summary + payload_name = self.description.metadata['payload']['filename'] + payload_summary = self.payload.analyse_contents(self.config) + metadata_contents = self.description.get_contents() + + task_summary = self.config['github']['task_summary_payload_template'].format( + payload_name=payload_name, + metadata_contents=metadata_contents, + payload_overview=payload_summary + ) + + # create HTML file with task summary in REPO-PR-SEQ directory + # TODO: add failure handling (capture result and act on it) + task_file_name = self.description.get_task_file_name() + commit_message = f"create summary for {task_file_name} in {feature_branch_name}" + self._safe_create_file(task_summary_file_path, commit_message, task_summary, + branch_name=feature_branch_name) + log_message(LoggingScope.TASK_OPS, 'INFO', "task summary file created: %s", task_summary_file_path) + + # return task summary + return task_summary + + @log_function_entry_exit() + def _create_pr_contents_overview(self) -> str: + """Create a contents overview for the pull request""" + # TODO: implement + feature_branch_name = self._determine_feature_branch_name() + task_pointer_file = self.description.task_object.remote_file_path + pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, feature_branch_name) + pr_dir = os.path.dirname(pull_request_dir) + directories = self._list_directory_contents(pr_dir, feature_branch_name) + contents_overview = "" + if directories: + contents_overview += "\n" + for directory in directories: + task_summary_file_path = f"{pr_dir}/{directory.name}/TaskSummary.html" + if self._path_exists_in_branch(task_summary_file_path, feature_branch_name): + file_contents = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name) + task_summary = base64.b64decode(file_contents.content).decode('utf-8') + contents_overview += f"{task_summary}\n" + else: + contents_overview += f"Task summary file not found: {task_summary_file_path}\n" + contents_overview += "\n" + else: + contents_overview += "No tasks found in this PR\n" + + print(f"contents_overview: {contents_overview}") + return contents_overview + + @log_function_entry_exit() + def _create_pull_request(self, feature_branch_name: str, default_branch_name: str): + """ + Create a PR from the feature branch to the default branch + + Args: + feature_branch_name: name of the feature branch + default_branch_name: name of the default branch + """ + pr_title_format = self.config['github']['grouped_pr_title'] + pr_body_format = self.config['github']['grouped_pr_body'] + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + pr_url = f"https://github.com/{repo_name}/pull/{pr_number}" + seq_num = self._determine_sequence_number_from_pull_request_directory() + pr_title = pr_title_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + repo=repo_name, + seq_num=seq_num, + ) + self._create_task_summary() + contents_overview = self._create_pr_contents_overview() + pr_body = pr_body_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + pr_url=pr_url, + repo=repo_name, + seq_num=seq_num, + contents=contents_overview, + analysis="
TO BE DONE
", + action="
TO BE DONE
", + ) + pr = self.git_repo.create_pull( + title=pr_title, + body=pr_body, + head=feature_branch_name, + base=default_branch_name + ) + log_message(LoggingScope.TASK_OPS, 'INFO', "PR created: %s", pr) + + @log_function_entry_exit() + def _update_pull_request(self, pull_request: PullRequest): + """ + Update the pull request + + Args: + pull_request: instance of the pull request + """ + # TODO: update sections (contents analysis, action) + repo_name = self.description.get_repo_name() + pr_number = self.description.get_pr_number() + pr_url = f"https://github.com/{repo_name}/pull/{pr_number}" + seq_num = self._determine_sequence_number_from_pull_request_directory() + + self._create_task_summary() + contents_overview = self._create_pr_contents_overview() + pr_body_format = self.config['github']['grouped_pr_body'] + pr_body = pr_body_format.format( + cvmfs_repo=self.cvmfs_repo, + pr=pr_number, + pr_url=pr_url, + repo=repo_name, + seq_num=seq_num, + contents=contents_overview, + analysis="
TO BE DONE
", + action="
TO BE DONE
", + ) + pull_request.edit(body=pr_body) + + log_message(LoggingScope.TASK_OPS, 'INFO', "PR updated: %s", pull_request) + + @log_function_entry_exit() + def _handle_add_payload_staged(self): + """Handler for ADD action in PAYLOAD_STAGED state""" + print("Handling ADD action in PAYLOAD_STAGED state: %s" % self.description.get_task_file_name()) + next_state = self._next_state(TaskState.PAYLOAD_STAGED) + approved_state = TaskState.APPROVED + log_message(LoggingScope.TASK_OPS, 'INFO', "next_state: %s, approved_state: %s", next_state, approved_state) + + default_branch_name = self.git_repo.default_branch + default_branch = self._get_branch_from_name(default_branch_name) + default_sha = default_branch.commit.sha + feature_branch_name = self._determine_feature_branch_name() + feature_branch = self._get_branch_from_name(feature_branch_name) + if not feature_branch: + # feature branch does not exist + # TODO: could have been merged already --> check if PR corresponding to the feature branch exists + # ASSUME: it has not existed before --> create it + log_message(LoggingScope.TASK_OPS, 'INFO', + "branch %s does not exist, creating it", feature_branch_name) + + feature_branch = self.git_repo.create_git_ref(f"refs/heads/{feature_branch_name}", default_sha) + log_message(LoggingScope.TASK_OPS, 'INFO', + "branch %s created: %s", feature_branch_name, feature_branch) + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "found existing branch for %s: %s", feature_branch_name, feature_branch) + + pull_request = self._find_pr_for_branch(feature_branch_name) + if not pull_request: + log_message(LoggingScope.TASK_OPS, 'INFO', + "no PR found for branch %s", feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._create_pull_request(feature_branch_name, default_branch_name) + + return TaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "found existing PR for branch %s: %s", feature_branch_name, pull_request) + # TODO: check if PR is open or closed + if pull_request.state == 'closed': + log_message(LoggingScope.TASK_OPS, 'INFO', + "PR %s is closed, creating issue", pull_request) + # TODO: create issue + return TaskState.PAYLOAD_STAGED + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "PR %s is open, updating task states", pull_request) + # TODO: add failure handling (capture result and act on it) + # THINK about what a failure would mean and what to do about it. 
+ self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name) + + # TODO: add failure handling (capture result and act on it) + self._update_pull_request(pull_request) + + return TaskState.PULL_REQUEST + + @log_function_entry_exit() + def _handle_add_pull_request(self): + """Handler for ADD action in PULL_REQUEST state""" + print("Handling ADD action in PULL_REQUEST state: %s" % self.description.get_task_file_name()) + # Implementation for adding in PULL_REQUEST state + # we got here because the state of the task is PULL_REQUEST in the default branch + # determine branch and PR and state of PR + # PR is open --> just return TaskState.PULL_REQUEST + # PR is closed & merged --> deployment is approved + # PR is closed & not merged --> deployment is rejected + feature_branch_name = self._determine_feature_branch_name() + # TODO: check if feature branch exists, for now ASSUME it does + pull_request = self._find_pr_for_branch(feature_branch_name) + if pull_request: + log_message(LoggingScope.TASK_OPS, 'INFO', + "found PR for branch %s: %s", feature_branch_name, pull_request) + if pull_request.state == 'closed': + if pull_request.merged: + log_message(LoggingScope.TASK_OPS, 'INFO', + "PR %s is closed and merged, returning APPROVED state", pull_request) + # TODO: How could we ended up here? state in default branch is PULL_REQUEST but + # PR is merged, hence it should have been in the APPROVED state + # ==> for now, just return TaskState.PULL_REQUEST + # + # there is the possibility that the PR was updated just before the + # PR was merged + # WHY is it a problem? because a task may have been accepted that wouldn't + # have been accepted or worse shouldn't been accepted + # WHAT to do? ACCEPT/IGNORE THE ISSUE FOR NOw + # HOWEVER, the contents of the PR directory may be inconsistent with + # respect to the TaskState file and missing TaskSummary.html file + # WE could create an issue and only return TaskState.APPROVED if the + # issue is closed + # WE could also defer all handling of this to the handler for the + # APPROVED state + # NOPE, we have to do some handling here, at least for the tasks where their + # state file did + # --> check if we could have ended up here? If so, create an issue. + # Do we need a state ISSUE_OPENED to avoid processing the task again? + return TaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "PR %s is closed and not merged, returning REJECTED state", pull_request) + # TODO: there is the possibility that the PR was updated just before the + # PR was closed + # WHY is it a problem? because a task may have been rejected that wouldn't + # have been rejected or worse shouldn't been rejected + # WHAT to do? 
ACCEPT/IGNORE THE ISSUE FOR NOw + # HOWEVER, the contents of the PR directory may be inconsistent with + # respect to the TaskState file and missing TaskSummary.html file + # WE could create an issue and only return TaskState.REJECTED if the + # issue is closed + # WE could also defer all handling of this to the handler for the + # REJECTED state + # FOR NOW, we assume that the task was rejected on purpose + # we need to change the state of the task in the default branch to REJECTED + self._update_task_state_file(TaskState.REJECTED) + return TaskState.REJECTED + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "PR %s is open, returning PULL_REQUEST state", pull_request) + return TaskState.PULL_REQUEST + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + "no PR found for branch %s", feature_branch_name) + # the method was called because the state of the task is PULL_REQUEST in the default branch + # however, it's weird that the PR was not found for the feature branch + # TODO: may create or update an issue for the task or deployment + return TaskState.PULL_REQUEST + + return TaskState.PULL_REQUEST + + @log_function_entry_exit() + def _perform_task_action(self) -> bool: + """Perform the task action""" + # TODO: support other actions than ADD + if self.action == EESSITaskAction.ADD: + return self._perform_task_add() + else: + raise ValueError(f"Task action '{self.action}' not supported (yet)") + + @log_function_entry_exit() + def _issue_exists(self, title: str, state: str = 'open') -> bool: + """ + Check if an issue with the given title and state already exists. + """ + issues = self.git_repo.get_issues(state=state) + for issue in issues: + if issue.title == title and issue.state == state: + return True + else: + return False + + @log_function_entry_exit() + def _perform_task_add(self) -> bool: + """Perform the ADD task action""" + # TODO: verify checksum here or before? 
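+        # A possible checksum check (sketch, not enabled): compare the payload's SHA-256 with the value
+        # recorded in the task metadata before running the ingestion script, e.g.
+        #   if sha256sum(str(self.payload.payload_object.local_file_path)) != \
+        #           self.description.metadata['payload']['sha256sum']:
+        #       return False
+        # (sha256sum would come from utils and may need to be imported in this module.)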
+ script = self.config['paths']['ingestion_script'] + sudo = ['sudo'] if self.config['cvmfs'].getboolean('ingest_as_root', True) else [] + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Running the ingestion script for %s...\n with script: %s\n with sudo: %s', + self.description.get_task_file_name(), + script, 'no' if sudo == [] else 'yes') + ingest_cmd = subprocess.run( + sudo + [script, self.cvmfs_repo, str(self.payload.payload_object.local_file_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Ingestion script returned code %s', ingest_cmd.returncode) + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Ingestion script stdout: %s', ingest_cmd.stdout.decode('UTF-8')) + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Ingestion script stderr: %s', ingest_cmd.stderr.decode('UTF-8')) + if ingest_cmd.returncode == 0: + next_state = self._next_state(TaskState.APPROVED) + self._update_task_state_file(next_state) + if self.config.has_section('slack') and self.config['slack'].getboolean('ingestion_notification', False): + send_slack_message( + self.config['secrets']['slack_webhook'], + self.config['slack']['ingestion_message'].format( + tarball=os.path.basename(self.payload.payload_object.local_file_path), + cvmfs_repo=self.cvmfs_repo) + ) + return True + else: + tarball = os.path.basename(self.payload.payload_object.local_file_path) + log_message(LoggingScope.STATE_OPS, 'ERROR', + 'Failed to add %s, return code %s', + tarball, + ingest_cmd.returncode) + + issue_title = f'Failed to add {tarball}' + log_message(LoggingScope.STATE_OPS, 'INFO', + "Creating issue for failed ingestion: title: '%s'", + issue_title) + + command = ' '.join(ingest_cmd.args) + failed_ingestion_issue_body = self.config['github']['failed_ingestion_issue_body'] + issue_body = failed_ingestion_issue_body.format( + command=command, + tarball=tarball, + return_code=ingest_cmd.returncode, + stdout=ingest_cmd.stdout.decode('UTF-8'), + stderr=ingest_cmd.stderr.decode('UTF-8') + ) + log_message(LoggingScope.STATE_OPS, 'INFO', + "Creating issue for failed ingestion: body: '%s'", + issue_body) + + if self._issue_exists(issue_title, state='open'): + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Failed to add %s, but an open issue already exists, skipping...', + os.path.basename(self.payload.payload_object.local_file_path)) + else: + log_message(LoggingScope.STATE_OPS, 'INFO', + 'Failed to add %s, but an open issue does not exist, creating one...', + os.path.basename(self.payload.payload_object.local_file_path)) + self.git_repo.create_issue(title=issue_title, body=issue_body) + return False + + @log_function_entry_exit() + def _handle_add_approved(self): + """Handler for ADD action in APPROVED state""" + print("Handling ADD action in APPROVED state: %s" % self.description.get_task_file_name()) + # Implementation for adding in APPROVED state + # If successful, _perform_task_action() will change the state + # to INGESTED on GitHub + try: + if self._perform_task_action(): + return TaskState.INGESTED + else: + return TaskState.APPROVED + except Exception as err: + log_message(LoggingScope.TASK_OPS, 'ERROR', + "Error performing task action: '%s'\nTraceback:\n%s", err, traceback.format_exc()) + return TaskState.APPROVED + + @log_function_entry_exit() + def _handle_add_ingested(self): + """Handler for ADD action in INGESTED state""" + print("Handling ADD action in INGESTED state: %s" % self.description.get_task_file_name()) + # Implementation for adding in INGESTED state + # DONT change 
state on GitHub, because the result + # (INGESTED/REJECTED) would be overwritten + return TaskState.DONE + + @log_function_entry_exit() + def _handle_add_rejected(self): + """Handler for ADD action in REJECTED state""" + print("Handling ADD action in REJECTED state: %s" % self.description.get_task_file_name()) + # Implementation for adding in REJECTED state + # DONT change state on GitHub, because the result + # (INGESTED/REJECTED) would be overwritten + return TaskState.DONE + + @log_function_entry_exit() + def __str__(self): + return f"EESSITask(description={self.description}, action={self.action}, state={self.determine_state()})" diff --git a/scripts/automated_ingestion/eessi_task_action.py b/scripts/automated_ingestion/eessi_task_action.py new file mode 100644 index 00000000..6f141435 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_action.py @@ -0,0 +1,12 @@ +from enum import Enum, auto + + +class EESSITaskAction(Enum): + NOP = auto() # perform no action + DELETE = auto() # perform a delete operation + ADD = auto() # perform an add operation + UPDATE = auto() # perform an update operation + UNKNOWN = auto() # unknown action + + def __str__(self): + return self.name.lower() diff --git a/scripts/automated_ingestion/eessi_task_description.py b/scripts/automated_ingestion/eessi_task_description.py new file mode 100644 index 00000000..5ff4c196 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_description.py @@ -0,0 +1,188 @@ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Tuple + +from eessi_data_object import EESSIDataAndSignatureObject +from utils import log_function_entry_exit, log_message, LoggingScope +from remote_storage import DownloadMode + + +@dataclass +class EESSITaskDescription: + """Class representing an EESSI task to be performed, including its metadata and associated data files.""" + + # The EESSI data and signature object associated with this task + task_object: EESSIDataAndSignatureObject + + # Whether the signature was successfully verified + signature_verified: bool = False + + # Metadata from the task description file + metadata: Dict[str, Any] = None + + # task element + task: Dict[str, Any] = None + + # source element + source: Dict[str, Any] = None + + @log_function_entry_exit() + def __init__(self, task_object: EESSIDataAndSignatureObject): + """ + Initialize an EESSITaskDescription object. + + Args: + task_object: The EESSI data and signature object associated with this task + """ + self.task_object = task_object + self.metadata = {} + + self.task_object.download(mode=DownloadMode.CHECK_REMOTE) + + # Verify signature and set initial state + self.signature_verified = self.task_object.verify_signature() + + # Try to read metadata (will only succeed if signature is verified) + try: + self._read_metadata() + except RuntimeError: + # Expected if signature is not verified yet + pass + + # TODO: Process the task file contents + # check if the task file contains a task field and add that to self + if 'task' in self.metadata: + self.task = self.metadata['task'] + else: + self.task = None + + # check if the task file contains a link2pr field and add that to source element + if 'link2pr' in self.metadata: + self.source = self.metadata['link2pr'] + else: + self.source = None + + @log_function_entry_exit() + def _read_metadata(self) -> None: + """ + Internal method to read and parse the metadata from the task description file. + Only reads metadata if the signature has been verified. 
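+
+        Raises:
+            RuntimeError: if the signature has not been verified.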
+        """
+        if not self.signature_verified:
+            log_message(LoggingScope.ERROR, 'ERROR', "Cannot read metadata: signature not verified for %s",
+                        self.task_object.local_file_path)
+            raise RuntimeError("Cannot read metadata: signature not verified")
+
+        try:
+            with open(self.task_object.local_file_path, 'r') as file:
+                self.raw_contents = file.read()
+                self.metadata = json.loads(self.raw_contents)
+                log_message(LoggingScope.DEBUG, 'DEBUG', "Successfully read metadata from %s",
+                            self.task_object.local_file_path)
+        except json.JSONDecodeError as err:
+            log_message(LoggingScope.ERROR, 'ERROR', "Failed to parse JSON in task description file %s: %s",
+                        self.task_object.local_file_path, str(err))
+            raise
+        except Exception as err:
+            log_message(LoggingScope.ERROR, 'ERROR', "Failed to read task description file %s: %s",
+                        self.task_object.local_file_path, str(err))
+            raise
+
+    @log_function_entry_exit()
+    def get_contents(self) -> str:
+        """
+        Get the contents of the task description / metadata file.
+        """
+        return self.raw_contents
+
+    @log_function_entry_exit()
+    def get_metadata_file_components(self) -> Tuple[str, str, str, str, str, str]:
+        """
+        Get the components of the metadata file name.
+
+        An example of the metadata file name is:
+          eessi-2023.06-software-linux-x86_64-amd-zen2-1745557626.tar.gz.meta.txt
+
+        The components are:
+          eessi: some prefix
+          VERSION: 2023.06
+          COMPONENT: software
+          OS: linux
+          ARCHITECTURE: x86_64-amd-zen2
+          TIMESTAMP: 1745557626
+          SUFFIX: tar.gz.meta.txt
+
+        The ARCHITECTURE component can include one to two hyphens.
+        The SUFFIX is the part of the last hyphen-separated element after its first dot.
+        """
+        # obtain file name from local file path using basename
+        file_name = Path(self.task_object.local_file_path).name
+        # split file_name into the part before the suffix and the suffix itself
+        # idea: take the last hyphen-separated element and split it on its first dot
+        suffix = file_name.split('-')[-1].split('.', 1)[1]
+        # remove the suffix (and the dot preceding it) from the end of the file name;
+        # note: str.strip() would remove individual characters, not the suffix as a whole
+        file_name_without_suffix = file_name[:-(len(suffix) + 1)]
+        # from file_name_without_suffix determine VERSION (2nd element), COMPONENT (3rd element), OS (4th element),
+        # ARCHITECTURE (5th to second last elements) and TIMESTAMP (last element)
+        components = file_name_without_suffix.split('-')
+        version = components[1]
+        component = components[2]
+        os = components[3]
+        architecture = '-'.join(components[4:-1])
+        timestamp = components[-1]
+        return version, component, os, architecture, timestamp, suffix
+
+    @log_function_entry_exit()
+    def get_metadata_value(self, key: str) -> str:
+        """
+        Get the value of a key from the task description / metadata file. 
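+
+        For example, get_metadata_value('repo') returns the repository name, looked up first in the
+        'task' element and then in the 'link2pr' (source) element of the metadata.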
+ """ + # check that key is defined and has a length > 0 + if not key or len(key) == 0: + raise ValueError("get_metadata_value: key is not defined or has a length of 0") + + value = None + task = self.task + source = self.source + # check if key is in task or source + if task and key in task: + value = task[key] + log_message(LoggingScope.TASK_OPS, 'INFO', + f"Value '{value}' for key '{key}' found in information from task metadata: {task}") + elif source and key in source: + value = source[key] + log_message(LoggingScope.TASK_OPS, 'INFO', + f"Value '{value}' for key '{key}' found in information from source metadata: {source}") + else: + log_message(LoggingScope.TASK_OPS, 'INFO', + f"Value for key '{key}' neither found in task metadata nor source metadata") + raise ValueError(f"Value for key '{key}' neither found in task metadata nor source metadata") + return value + + @log_function_entry_exit() + def get_pr_number(self) -> str: + """ + Get the PR number from the task description / metadata file. + """ + return self.get_metadata_value('pr') + + @log_function_entry_exit() + def get_repo_name(self) -> str: + """ + Get the repository name from the task description / metadata file. + """ + return self.get_metadata_value('repo') + + @log_function_entry_exit() + def get_task_file_name(self) -> str: + """ + Get the file name from the task description / metadata file. + """ + # get file name from remote file path using basename + file_name = Path(self.task_object.remote_file_path).name + return file_name + + @log_function_entry_exit() + def __str__(self) -> str: + """Return a string representation of the EESSITaskDescription object.""" + return f"EESSITaskDescription({self.task_object.local_file_path}, verified={self.signature_verified})" diff --git a/scripts/automated_ingestion/eessi_task_payload.py b/scripts/automated_ingestion/eessi_task_payload.py new file mode 100644 index 00000000..cb39cc81 --- /dev/null +++ b/scripts/automated_ingestion/eessi_task_payload.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +import tarfile +from pathlib import PurePosixPath +import os +from typing import Dict + +from eessi_data_object import EESSIDataAndSignatureObject +from utils import log_function_entry_exit +from remote_storage import DownloadMode + + +@dataclass +class EESSITaskPayload: + """Class representing an EESSI task payload (tarball/artifact) and its signature.""" + + # The EESSI data and signature object associated with this payload + payload_object: EESSIDataAndSignatureObject + + # Whether the signature was successfully verified + signature_verified: bool = False + + # possibly at a later point in time, we will add inferred metadata here + # such as the prefix in a tarball, the main elements, or which software + # package it includes + + @log_function_entry_exit() + def __init__(self, payload_object: EESSIDataAndSignatureObject): + """ + Initialize an EESSITaskPayload object. 
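+
+        Note: the constructor eagerly downloads the payload and its signature and verifies the
+        signature; the result is stored in the signature_verified attribute.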
+ + Args: + payload_object: The EESSI data and signature object associated with this payload + """ + self.payload_object = payload_object + + # Download the payload and its signature + self.payload_object.download(mode=DownloadMode.CHECK_REMOTE) + + # Verify signature + self.signature_verified = self.payload_object.verify_signature() + + @log_function_entry_exit() + def analyse_contents(self, config: Dict) -> str: + """Analyse the contents of the payload and return a summary in a ready-to-use HTML format.""" + tar = tarfile.open(self.payload_object.local_file_path, 'r') + members = tar.getmembers() + tar_num_members = len(members) + paths = sorted([m.path for m in members]) + + if tar_num_members < 100: + tar_members_desc = "Full listing of the contents of the tarball:" + members_list = paths + + else: + tar_members_desc = "Summarized overview of the contents of the tarball:" + # determine prefix after filtering out '/init' subdirectory, + # to get actual prefix for specific CPU target (like '2023.06/software/linux/aarch64/neoverse_v1') + init_subdir = os.path.join('*', 'init') + non_init_paths = sorted( + [path for path in paths if not any(parent.match(init_subdir) for parent in PurePosixPath(path).parents)] + ) + if non_init_paths: + prefix = os.path.commonprefix(non_init_paths) + else: + prefix = os.path.commonprefix(paths) + + # TODO: this only works for software tarballs, how to handle compat layer tarballs? + swdirs = [ # all directory names with the pattern: /software// + member.path + for member in members + if member.isdir() and PurePosixPath(member.path).match(os.path.join(prefix, 'software', '*', '*')) + ] + modfiles = [ # all filenames with the pattern: /modules///*.lua + member.path + for member in members + if member.isfile() and + PurePosixPath(member.path).match(os.path.join(prefix, 'modules', '*', '*', '*.lua')) + ] + other = [ # anything that is not in /software nor /modules + member.path + for member in members + if (not PurePosixPath(prefix).joinpath('software') in PurePosixPath(member.path).parents + and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(member.path).parents) + # if not fnmatch.fnmatch(m.path, os.path.join(prefix, 'software', '*')) + # and not fnmatch.fnmatch(m.path, os.path.join(prefix, 'modules', '*')) + ] + members_list = sorted(swdirs + modfiles + other) + + # Construct the overview + overview = config['github']['task_summary_payload_overview_template'].format( + tar_num_members=tar_num_members, + bucket_url=self.payload_object.remote_client.get_bucket_url(), + remote_file_path=self.payload_object.remote_file_path, + tar_members_desc=tar_members_desc, + tar_members='\n'.join(members_list) + ) + + # Make sure that the overview does not exceed Github's maximum length (65536 characters) + if len(overview) > 60000: + overview = overview[:60000] + "\n\nWARNING: output exceeded the maximum length and was truncated!\n```" + return overview + + @log_function_entry_exit() + def __str__(self) -> str: + """Return a string representation of the EESSITaskPayload object.""" + return f"EESSITaskPayload({self.payload_object.local_file_path}, verified={self.signature_verified})" diff --git a/scripts/automated_ingestion/eessitarball.py b/scripts/automated_ingestion/eessitarball.py index 40ac6fa1..cc3c4ae4 100644 --- a/scripts/automated_ingestion/eessitarball.py +++ b/scripts/automated_ingestion/eessitarball.py @@ -1,11 +1,9 @@ -from utils import send_slack_message, sha256sum +from utils import send_slack_message, sha256sum, log_function_entry_exit, 
log_message, LoggingScope from pathlib import PurePosixPath -import boto3 import github import json -import logging import os import subprocess import tarfile @@ -19,7 +17,8 @@ class EessiTarball: for which it interfaces with the S3 bucket, GitHub, and CVMFS. """ - def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo): + @log_function_entry_exit() + def __init__(self, object_name, config, git_staging_repo, s3_bucket, cvmfs_repo): """Initialize the tarball object.""" self.config = config self.git_repo = git_staging_repo @@ -27,15 +26,14 @@ def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo self.metadata_sig_file = self.metadata_file + config['signatures']['signature_file_extension'] self.object = object_name self.object_sig = object_name + config['signatures']['signature_file_extension'] - self.s3 = s3 - self.bucket = bucket + self.s3_bucket = s3_bucket self.cvmfs_repo = cvmfs_repo self.local_path = os.path.join(config['paths']['download_dir'], os.path.basename(object_name)) self.local_sig_path = self.local_path + config['signatures']['signature_file_extension'] self.local_metadata_path = self.local_path + config['paths']['metadata_file_extension'] self.local_metadata_sig_path = self.local_metadata_path + config['signatures']['signature_file_extension'] self.sig_verified = False - self.url = f'https://{bucket}.s3.amazonaws.com/{object_name}' + self.url = f'https://{s3_bucket.bucket}.s3.amazonaws.com/{object_name}' self.states = { 'new': {'handler': self.mark_new_tarball_as_staged, 'next_state': 'staged'}, @@ -49,6 +47,7 @@ def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo # Find the initial state of this tarball. self.state = self.find_state() + @log_function_entry_exit() def download(self, force=False): """ Download this tarball and its corresponding metadata file, if this hasn't been already done. @@ -57,32 +56,41 @@ def download(self, force=False): (self.object, self.local_path, self.object_sig, self.local_sig_path), (self.metadata_file, self.local_metadata_path, self.metadata_sig_file, self.local_metadata_sig_path), ] + log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading %s", files) skip = False for (object, local_file, sig_object, local_sig_file) in files: if force or not os.path.exists(local_file): # First we try to download signature file, which may or may not be available # and may be optional or required. try: - self.s3.download_file(self.bucket, sig_object, local_sig_file) - except: + log_msg = "Downloading signature file %s to %s" + log_message(LoggingScope.DOWNLOAD, 'INFO', log_msg, sig_object, local_sig_file) + self.s3_bucket.download_file(self.s3_bucket.bucket, sig_object, local_sig_file) + except Exception as err: + log_msg = 'Failed to download signature file %s for %s from %s to %s.' if self.config['signatures'].getboolean('signatures_required', True): - logging.error( - f'Failed to download signature file {sig_object} for {object} from {self.bucket} to {local_sig_file}.' + log_msg += '\nException: %s' + log_message( + LoggingScope.ERROR, 'ERROR', log_msg, + sig_object, object, self.s3_bucket.bucket, local_sig_file, err ) skip = True break else: - logging.warning( - f'Failed to download signature file {sig_object} for {object} from {self.bucket} to {local_sig_file}. ' + - 'Ignoring this, because signatures are not required with the current configuration.' + log_msg += ' Ignoring this, because signatures are not required' + log_msg += ' with the current configuration.' 
+ log_msg += '\nException: %s' + log_message( + LoggingScope.DOWNLOAD, 'WARNING', log_msg, + sig_object, object, self.s3_bucket.bucket, local_sig_file, err ) # Now we download the file itself. try: - self.s3.download_file(self.bucket, object, local_file) - except: - logging.error( - f'Failed to download {object} from {self.bucket} to {local_file}.' - ) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading file %s to %s", object, local_file) + self.s3_bucket.download_file(self.s3_bucket.bucket, object, local_file) + except Exception as err: + log_msg = 'Failed to download %s from %s to %s.\nException: %s' + log_message(LoggingScope.ERROR, 'ERROR', log_msg, object, self.s3_bucket.bucket, local_file, err) skip = True break # If any required download failed, make sure to skip this tarball completely. @@ -90,27 +98,30 @@ def download(self, force=False): self.local_path = None self.local_metadata_path = None + @log_function_entry_exit() def find_state(self): """Find the state of this tarball by searching through the state directories in the git repository.""" + log_message(LoggingScope.DEBUG, 'DEBUG', "Find state for %s", self.object) for state in list(self.states.keys()): - # iterate through the state dirs and try to find the tarball's metadata file try: self.git_repo.get_contents(state + '/' + self.metadata_file) + log_msg = "Found metadata file %s in state: %s" + log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.metadata_file, state) return state except github.UnknownObjectException: # no metadata file found in this state's directory, so keep searching... continue - except github.GithubException as e: - if e.status == 404: + except github.GithubException as err: + if err.status == 404: # no metadata file found in this state's directory, so keep searching... continue else: # if there was some other (e.g. connection) issue, abort the search for this tarball - logging.warning(f'Unable to determine the state of {self.object}, the GitHub API returned status {e.status}!') + log_msg = 'Unable to determine the state of %s, the GitHub API returned status %s!' 
+ log_message(LoggingScope.ERROR, 'WARNING', log_msg, self.object, err.status) return "unknown" - else: - # if no state was found, we assume this is a new tarball that was ingested to the bucket - return "new" + log_message(LoggingScope.STATE_OPS, 'INFO', "Tarball %s is new", self.metadata_file) + return "new" def get_contents_overview(self): """Return an overview of what is included in the tarball.""" @@ -128,7 +139,9 @@ def get_contents_overview(self): # determine prefix after filtering out '/init' subdirectory, # to get actual prefix for specific CPU target (like '2023.06/software/linux/aarch64/neoverse_v1') init_subdir = os.path.join('*', 'init') - non_init_paths = sorted([p for p in paths if not any(x.match(init_subdir) for x in PurePosixPath(p).parents)]) + non_init_paths = sorted( + [p for p in paths if not any(x.match(init_subdir) for x in PurePosixPath(p).parents)] + ) if non_init_paths: prefix = os.path.commonprefix(non_init_paths) else: @@ -148,8 +161,8 @@ def get_contents_overview(self): other = [ # anything that is not in /software nor /modules m.path for m in members - if not PurePosixPath(prefix).joinpath('software') in PurePosixPath(m.path).parents - and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(m.path).parents + if (not PurePosixPath(prefix).joinpath('software') in PurePosixPath(m.path).parents + and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(m.path).parents) # if not fnmatch.fnmatch(m.path, os.path.join(prefix, 'software', '*')) # and not fnmatch.fnmatch(m.path, os.path.join(prefix, 'modules', '*')) ] @@ -181,88 +194,121 @@ def run_handler(self): handler = self.states[self.state]['handler'] handler() + def to_string(self, oneline=False): + """Serialize tarball info so it can be printed.""" + str = f"tarball: {self.object}" + sep = "\n" if not oneline else "," + str += f"{sep} metadt: {self.metadata_file}" + str += f"{sep} bucket: {self.s3_bucket.bucket}" + str += f"{sep} cvmfs.: {self.cvmfs_repo}" + str += f"{sep} GHrepo: {self.git_repo}" + return str + + @log_function_entry_exit() def verify_signatures(self): - """Verify the signatures of the downloaded tarball and metadata file using the corresponding signature files.""" - + """ + Verify the signatures of the downloaded tarball and metadata file + using the corresponding signature files. + """ sig_missing_msg = 'Signature file %s is missing.' sig_missing = False for sig_file in [self.local_sig_path, self.local_metadata_sig_path]: if not os.path.exists(sig_file): - logging.warning(sig_missing_msg % sig_file) + log_message(LoggingScope.VERIFICATION, 'WARNING', sig_missing_msg, sig_file) sig_missing = True + log_message(LoggingScope.VERIFICATION, 'INFO', "Signature file %s is missing.", sig_file) if sig_missing: # If signature files are missing, we return a failure, # unless the configuration specifies that signatures are not required. if self.config['signatures'].getboolean('signatures_required', True): + log_message(LoggingScope.ERROR, 'ERROR', "Signature file %s is missing.", sig_file) return False else: + log_msg = "Signature file %s is missing, but signatures are not required." + log_message(LoggingScope.VERIFICATION, 'INFO', log_msg, sig_file) return True # If signatures are provided, we should always verify them, regardless of the signatures_required. # In order to do so, we need the verification script and an allowed signers file. 
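+        # Note: 'signature_verification_runenv' may hold a (possibly empty) wrapper command that is
+        # prepended to the verification call below, e.g. to run the script in a specific runtime environment.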
+ verify_runenv = self.config['signatures']['signature_verification_runenv'].split() verify_script = self.config['signatures']['signature_verification_script'] allowed_signers_file = self.config['signatures']['allowed_signers_file'] if not os.path.exists(verify_script): - logging.error(f'Unable to verify signatures, the specified signature verification script does not exist!') + log_msg = 'Unable to verify signatures, the specified signature verification script does not exist!' + log_message(LoggingScope.ERROR, 'ERROR', log_msg) return False if not os.path.exists(allowed_signers_file): - logging.error(f'Unable to verify signatures, the specified allowed signers file does not exist!') + log_msg = 'Unable to verify signatures, the specified allowed signers file does not exist!' + log_message(LoggingScope.ERROR, 'ERROR', log_msg) return False - for (file, sig_file) in [(self.local_path, self.local_sig_path), (self.local_metadata_path, self.local_metadata_sig_path)]: + for (file, sig_file) in [ + (self.local_path, self.local_sig_path), + (self.local_metadata_path, self.local_metadata_sig_path) + ]: + command = verify_runenv + [verify_script, '--verify', '--allowed-signers-file', allowed_signers_file, + '--file', file, '--signature-file', sig_file] + log_message(LoggingScope.VERIFICATION, 'INFO', "Running command: %s", ' '.join(command)) + verify_cmd = subprocess.run( - [verify_script, '--verify', '--allowed-signers-file', allowed_signers_file, '--file', file, '--signature-file', sig_file], + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if verify_cmd.returncode == 0: - logging.debug(f'Signature for {file} successfully verified.') + log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Signature for %s successfully verified.', file) else: - logging.error(f'Failed to verify signature for {file}.') + log_message(LoggingScope.ERROR, 'ERROR', 'Failed to verify signature for %s.', file) + log_message(LoggingScope.ERROR, 'ERROR', " stdout: %s", verify_cmd.stdout.decode('UTF-8')) + log_message(LoggingScope.ERROR, 'ERROR', " stderr: %s", verify_cmd.stderr.decode('UTF-8')) return False self.sig_verified = True return True + @log_function_entry_exit() def verify_checksum(self): """Verify the checksum of the downloaded tarball with the one in its metadata file.""" local_sha256 = sha256sum(self.local_path) meta_sha256 = None with open(self.local_metadata_path, 'r') as meta: meta_sha256 = json.load(meta)['payload']['sha256sum'] - logging.debug(f'Checksum of downloaded tarball: {local_sha256}') - logging.debug(f'Checksum stored in metadata file: {meta_sha256}') + log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Checksum of downloaded tarball: %s', local_sha256) + log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Checksum stored in metadata file: %s', meta_sha256) return local_sha256 == meta_sha256 + @log_function_entry_exit() def ingest(self): """Process a tarball that is ready to be ingested by running the ingestion script.""" - #TODO: check if there is an open issue for this tarball, and if there is, skip it. - logging.info(f'Tarball {self.object} is ready to be ingested.') + # TODO: check if there is an open issue for this tarball, and if there is, skip it. 
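+        # A possible guard (sketch, not enabled): skip tarballs that already have an open ingestion issue,
+        # e.g.
+        #   if self.issue_exists(f'Failed to ingest {self.object}', state='open'):
+        #       return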
+ log_message(LoggingScope.STATE_OPS, 'INFO', 'Tarball %s is ready to be ingested.', self.object) self.download() - logging.info('Verifying its signature...') + log_message(LoggingScope.VERIFICATION, 'INFO', 'Verifying its signature...') if not self.verify_signatures(): issue_msg = f'Failed to verify signatures for `{self.object}`' - logging.error(issue_msg) + log_message(LoggingScope.ERROR, 'ERROR', issue_msg) if not self.issue_exists(issue_msg, state='open'): self.git_repo.create_issue(title=issue_msg, body=issue_msg) return else: - logging.debug(f'Signatures of {self.object} and its metadata file successfully verified.') + log_msg = 'Signatures of %s and its metadata file successfully verified.' + log_message(LoggingScope.VERIFICATION, 'DEBUG', log_msg, self.object) - logging.info('Verifying its checksum...') + log_message(LoggingScope.VERIFICATION, 'INFO', 'Verifying its checksum...') if not self.verify_checksum(): issue_msg = f'Failed to verify checksum for `{self.object}`' - logging.error(issue_msg) + log_message(LoggingScope.ERROR, 'ERROR', issue_msg) if not self.issue_exists(issue_msg, state='open'): self.git_repo.create_issue(title=issue_msg, body=issue_msg) return else: - logging.debug(f'Checksum of {self.object} matches the one in its metadata file.') + log_msg = 'Checksum of %s matches the one in its metadata file.' + log_message(LoggingScope.VERIFICATION, 'DEBUG', log_msg, self.object) script = self.config['paths']['ingestion_script'] sudo = ['sudo'] if self.config['cvmfs'].getboolean('ingest_as_root', True) else [] - logging.info(f'Running the ingestion script for {self.object}...') + log_message(LoggingScope.STATE_OPS, 'INFO', 'Running the ingestion script for %s...', self.object) ingest_cmd = subprocess.run( sudo + [script, self.cvmfs_repo, self.local_path], stdout=subprocess.PIPE, @@ -273,7 +319,9 @@ def ingest(self): if self.config.has_section('slack') and self.config['slack'].getboolean('ingestion_notification', False): send_slack_message( self.config['secrets']['slack_webhook'], - self.config['slack']['ingestion_message'].format(tarball=os.path.basename(self.object), cvmfs_repo=self.cvmfs_repo) + self.config['slack']['ingestion_message'].format( + tarball=os.path.basename(self.object), + cvmfs_repo=self.cvmfs_repo) ) else: issue_title = f'Failed to ingest {self.object}' @@ -285,133 +333,262 @@ def ingest(self): stderr=ingest_cmd.stderr.decode('UTF-8'), ) if self.issue_exists(issue_title, state='open'): - logging.info(f'Failed to ingest {self.object}, but an open issue already exists, skipping...') + log_msg = 'Failed to ingest %s, but an open issue already exists, skipping...' + log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.object) else: self.git_repo.create_issue(title=issue_title, body=issue_body) def print_ingested(self): """Process a tarball that has already been ingested.""" - logging.info(f'{self.object} has already been ingested, skipping...') + log_message(LoggingScope.STATE_OPS, 'INFO', '%s has already been ingested, skipping...', self.object) - def mark_new_tarball_as_staged(self): + @log_function_entry_exit() + def mark_new_tarball_as_staged(self, branch=None): """Process a new tarball that was added to the staging bucket.""" next_state = self.next_state(self.state) - logging.info(f'Found new tarball {self.object}, downloading it...') + log_msg = 'Found new tarball %s, downloading it...' + log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.object) # Download the tarball and its metadata file. 
# Use force as it may be a new attempt for an existing tarball that failed before. self.download(force=True) if not self.local_path or not self.local_metadata_path: - logging.warn('Skipping this tarball...') + log_msg = "Skipping tarball %s - download failed" + log_message(LoggingScope.STATE_OPS, 'WARNING', log_msg, self.object) return # Verify the signatures of the tarball and metadata file. if not self.verify_signatures(): - logging.warn('Signature verification of the tarball or its metadata failed, skipping this tarball...') + log_msg = "Skipping tarball %s - signature verification failed" + log_message(LoggingScope.STATE_OPS, 'WARNING', log_msg, self.object) + return + # If no branch is provided, use the main branch + target_branch = branch if branch else 'main' + log_msg = "Adding metadata to '%s' folder in %s branch" + log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, next_state, target_branch) + + file_path_staged = next_state + '/' + self.metadata_file contents = '' with open(self.local_metadata_path, 'r') as meta: contents = meta.read() - - logging.info(f'Adding tarball\'s metadata to the "{next_state}" folder of the git repository.') - file_path_staged = next_state + '/' + self.metadata_file - new_file = self.git_repo.create_file(file_path_staged, 'new tarball', contents, branch='main') + self.git_repo.create_file(file_path_staged, 'new tarball', contents, branch=target_branch) self.state = next_state - self.run_handler() + if not branch: # Only run handler if we're not part of a group + self.run_handler() def print_rejected(self): """Process a (rejected) tarball for which the corresponding PR has been closed witout merging.""" - logging.info("This tarball was rejected, so we're skipping it.") + log_message(LoggingScope.STATE_OPS, 'INFO', "This tarball was rejected, so we're skipping it.") # Do we want to delete rejected tarballs at some point? def print_unknown(self): """Process a tarball which has an unknown state.""" - logging.info("The state of this tarball could not be determined, so we're skipping it.") + log_msg = "The state of this tarball could not be determined," + log_msg += " so we're skipping it." 
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg) + + def find_next_sequence_number(self, repo, pr_id): + """Find the next available sequence number for staging PRs of a source PR.""" + # Search for existing branches for this source PR + base_branch = f'staging-{repo.replace("/", "-")}-pr-{pr_id}-seq-' + existing_branches = [ + ref.ref for ref in self.git_repo.get_git_refs() + if ref.ref.startswith(f'refs/heads/{base_branch}') + ] + + if not existing_branches: + return 1 + + # Extract sequence numbers from existing branches + sequence_numbers = [] + for branch in existing_branches: + try: + # Extract the sequence number from branch name + # Format: staging-<repo>-pr-<pr_id>-seq-<sequence> + sequence = int(branch.split('-')[-1]) + sequence_numbers.append(sequence) + except (ValueError, IndexError): + continue + + if not sequence_numbers: + return 1 - def make_approval_request(self): + # Return next available sequence number + return max(sequence_numbers) + 1 + + @log_function_entry_exit() + def make_approval_request(self, tarballs_in_group=None): """Process a staged tarball by opening a pull request for ingestion approval.""" next_state = self.next_state(self.state) - file_path_staged = self.state + '/' + self.metadata_file - file_path_to_ingest = next_state + '/' + self.metadata_file + log_msg = "Making approval request for tarball %s in state %s to %s" + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object, self.state, next_state) + # obtain link2pr information (repo and pr_id) from metadata file + with open(self.local_metadata_path, 'r') as meta: + metadata = meta.read() + meta_dict = json.loads(metadata) + repo, pr_id = meta_dict['link2pr']['repo'], meta_dict['link2pr']['pr'] - filename = os.path.basename(self.object) - tarball_metadata = self.git_repo.get_contents(file_path_staged) - git_branch = filename + '_' + next_state - self.download() + # find next sequence number for staging PRs of this source PR + sequence = self.find_next_sequence_number(repo, pr_id) + git_branch = f'staging-{repo.replace("/", "-")}-pr-{pr_id}-seq-{sequence}' + # Check if git_branch exists and what the status of the corresponding PR is main_branch = self.git_repo.get_branch('main') if git_branch in [branch.name for branch in self.git_repo.get_branches()]: - # Existing branch found for this tarball, so we've run this step before. - # Try to find out if there's already a PR as well... - logging.info("Branch already exists for " + self.object) - # Filtering with only head= returns all prs if there's no match, so double-check - find_pr = [pr for pr in self.git_repo.get_pulls(head=git_branch, state='all') if pr.head.ref == git_branch] - logging.debug('Found PRs: ' + str(find_pr)) + log_msg = "Branch %s already exists, checking the status of the corresponding PR..." + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, git_branch) + find_pr = [pr for pr in self.git_repo.get_pulls(head=git_branch, state='all') + if pr.head.ref == git_branch] if find_pr: - # So, we have a branch and a PR for this tarball (if there are more, pick the first one)... pr = find_pr.pop(0) - logging.info(f'PR {pr.number} found for {self.object}') if pr.state == 'open': - # The PR is still open, so it hasn't been reviewed yet: ignore this tarball. - logging.info('PR is still open, skipping this tarball...') + log_message(LoggingScope.GITHUB_OPS, 'INFO', 'PR is still open, skipping this tarball...') return elif pr.state == 'closed' and not pr.merged: - # The PR was closed but not merged, i.e. it was rejected for ingestion.
- logging.info('PR was rejected') + log_message(LoggingScope.GITHUB_OPS, 'INFO', 'PR was rejected') self.reject() return else: - logging.warn(f'Warning, tarball {self.object} is in a weird state:') - logging.warn(f'Branch: {git_branch}\nPR: {pr}\nPR state: {pr.state}\nPR merged: {pr.merged}') + log_msg = 'Warning, tarball %s is in a weird state:' + log_message(LoggingScope.GITHUB_OPS, 'WARNING', log_msg, self.object) + log_msg = 'Branch: %s\nPR: %s\nPR state: %s\nPR merged: %s' + log_message(LoggingScope.GITHUB_OPS, 'WARNING', log_msg, + git_branch, pr, pr.state, pr.merged) + # TODO: should we delete the branch or open an issue? + return else: - # There is a branch, but no PR for this tarball. - # This is weird, so let's remove the branch and reprocess the tarball. - logging.info(f'Tarball {self.object} has a branch, but no PR.') - logging.info(f'Removing existing branch...') + log_msg = 'Tarball %s has a branch, but no PR.' + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object) + log_message(LoggingScope.GITHUB_OPS, 'INFO', 'Removing existing branch...') ref = self.git_repo.get_git_ref(f'heads/{git_branch}') ref.delete() - logging.info(f'Making pull request to get ingestion approval for {self.object}.') - # Create a new branch + + # Create new branch self.git_repo.create_git_ref(ref='refs/heads/' + git_branch, sha=main_branch.commit.sha) - # Move the file to the directory of the next stage in this branch - self.move_metadata_file(self.state, next_state, branch=git_branch) - # Get metadata file contents - metadata = '' - with open(self.local_metadata_path, 'r') as meta: - metadata = meta.read() - meta_dict = json.loads(metadata) - repo, pr_id = meta_dict['link2pr']['repo'], meta_dict['link2pr']['pr'] - pr_url = f"https://github.com/{repo}/pull/{pr_id}" - # Try to get the tarball contents and open a PR to get approval for the ingestion + + # Move metadata file(s) to approved directory + log_msg = "Moving metadata for %s from %s to %s in branch %s" + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, + self.object, self.state, next_state, git_branch) + if tarballs_in_group is None: + log_message(LoggingScope.GITHUB_OPS, 'INFO', "Moving metadata for individual tarball to staged") + self.move_metadata_file(self.state, next_state, branch=git_branch) + else: + log_msg = "Moving metadata for %d tarballs to staged" + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, len(tarballs_in_group)) + for tarball in tarballs_in_group: + temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo) + temp_tar.move_metadata_file(self.state, next_state, branch=git_branch) + + # Create PR with appropriate template try: - tarball_contents = self.get_contents_overview() - pr_body = self.config['github']['pr_body'].format( - cvmfs_repo=self.cvmfs_repo, - pr_url=pr_url, - tar_overview=self.get_contents_overview(), - metadata=metadata, - ) - pr_title = '[%s] Ingest %s' % (self.cvmfs_repo, filename) + pr_url = f"https://github.com/{repo}/pull/{pr_id}" + if tarballs_in_group is None: + log_msg = "Creating PR for individual tarball: %s" + log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object) + tarball_contents = self.get_contents_overview() + pr_body = self.config['github']['individual_pr_body'].format( + cvmfs_repo=self.cvmfs_repo, + pr_url=pr_url, + tar_overview=tarball_contents, + metadata=metadata, + ) + pr_title = f'[{self.cvmfs_repo}] Ingest {os.path.basename(self.object)}' + else: + # Group of tarballs + tar_overviews = [] + for tarball in 
tarballs_in_group: + try: + temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo) + temp_tar.download() + overview = temp_tar.get_contents_overview() + tar_details_tpl = "
<details>\n<summary>Contents of %s</summary>\n\n%s\n</details>
\n" + tar_overviews.append(tar_details_tpl % (tarball, overview)) + except Exception as err: + log_msg = "Failed to get contents overview for %s: %s" + log_message(LoggingScope.ERROR, 'ERROR', log_msg, tarball, err) + tar_details_tpl = "
<details>\n<summary>Contents of %s</summary>\n\n" + tar_details_tpl += "Failed to get contents overview: %s\n</details>
\n" + tar_overviews.append(tar_details_tpl % (tarball, err)) + + pr_body = self.config['github']['grouped_pr_body'].format( + cvmfs_repo=self.cvmfs_repo, + pr_url=pr_url, + tarballs=self.format_tarball_list(tarballs_in_group), + metadata=self.format_metadata_list(tarballs_in_group), + tar_overview="\n".join(tar_overviews) + ) + pr_title = f'[{self.cvmfs_repo}] Staging PR #{sequence} for {repo}#{pr_id}' + + # Add signature verification status if applicable if self.sig_verified: - pr_body += "\n\n:heavy_check_mark: :closed_lock_with_key: The signature of this tarball has been successfully verified." + pr_body += "\n\n:heavy_check_mark: :closed_lock_with_key: " + pr_body += "The signature of this tarball has been successfully verified." pr_title += ' :closed_lock_with_key:' + self.git_repo.create_pull(title=pr_title, body=pr_body, head=git_branch, base='main') + log_message(LoggingScope.GITHUB_OPS, 'INFO', "Created PR: %s", pr_title) + except Exception as err: - issue_title = f'Failed to get contents of {self.object}' - issue_body = self.config['github']['failed_tarball_overview_issue_body'].format( - tarball=self.object, - error=err + log_message(LoggingScope.ERROR, 'ERROR', "Failed to create PR: %s", err) + if not self.issue_exists(f'Failed to get contents of {self.object}', state='open'): + self.git_repo.create_issue( + title=f'Failed to get contents of {self.object}', + body=self.config['github']['failed_tarball_overview_issue_body'].format( + tarball=self.object, + error=err + ) + ) + + def format_tarball_list(self, tarballs): + """Format a list of tarballs with checkboxes for approval.""" + formatted = "### Tarballs to be ingested\n\n" + for tarball in tarballs: + formatted += f"- [ ] {tarball}\n" + return formatted + + def format_metadata_list(self, tarballs): + """Format metadata for all tarballs in collapsible sections.""" + formatted = "### Metadata\n\n" + for tarball in tarballs: + with open(self.get_metadata_path(tarball), 'r') as meta: + metadata = meta.read() + formatted += ( + f"
<details>\n<summary>Metadata for {tarball}</summary>\n\n" + f"```\n{metadata}\n```\n</details>
\n\n" + ) + return formatted + + def get_metadata_path(self, tarball=None): + """ + Return the local path of the metadata file. + + Args: + tarball (str, optional): Name of the tarball to get metadata path for. + If None, use the current tarball's metadata file. + """ + if tarball is None: + # For single tarball, use the instance's metadata file + if not self.local_metadata_path: + self.local_metadata_path = os.path.join( + self.config['paths']['download_dir'], + os.path.basename(self.metadata_file) + ) + return self.local_metadata_path + else: + # For group of tarballs, construct path from tarball name + return os.path.join( + self.config['paths']['download_dir'], + os.path.basename(tarball) + self.config['paths']['metadata_file_extension'] ) - if len([i for i in self.git_repo.get_issues(state='open') if i.title == issue_title]) == 0: - self.git_repo.create_issue(title=issue_title, body=issue_body) - else: - logging.info(f'Failed to create tarball overview, but an issue already exists.') def move_metadata_file(self, old_state, new_state, branch='main'): """Move the metadata file of a tarball from an old state's directory to a new state's directory.""" file_path_old = old_state + '/' + self.metadata_file file_path_new = new_state + '/' + self.metadata_file - logging.debug(f'Moving metadata file {self.metadata_file} from {file_path_old} to {file_path_new}.') + log_message(LoggingScope.GITHUB_OPS, 'INFO', 'Moving metadata file %s from %s to %s in branch %s', + self.metadata_file, file_path_old, file_path_new, branch) tarball_metadata = self.git_repo.get_contents(file_path_old) # Remove the metadata file from the old state's directory... self.git_repo.delete_file(file_path_old, 'remove from ' + old_state, sha=tarball_metadata.sha, branch=branch) @@ -419,10 +596,54 @@ def move_metadata_file(self, old_state, new_state, branch='main'): self.git_repo.create_file(file_path_new, 'move to ' + new_state, tarball_metadata.decoded_content, branch=branch) + def process_pr_merge(self, pr_number): + """Process a merged PR by handling the checkboxes and moving tarballs to appropriate states.""" + pr = self.git_repo.get_pull(pr_number) + + # Get the branch name + branch_name = pr.head.ref + + # Get the list of tarballs from the PR body + tarballs = self.extract_tarballs_from_pr_body(pr.body) + + # Get the checked status for each tarball + checked_tarballs = self.extract_checked_tarballs(pr.body) + + # Process each tarball + for tarball in tarballs: + if tarball in checked_tarballs: + # Move to approved state + self.move_metadata_file('staged', 'approved', branch=branch_name) + else: + # Move to rejected state + self.move_metadata_file('staged', 'rejected', branch=branch_name) + + # Delete the branch after processing + ref = self.git_repo.get_git_ref(f'heads/{branch_name}') + ref.delete() + + def extract_checked_tarballs(self, pr_body): + """Extract list of checked tarballs from PR body.""" + checked_tarballs = [] + for line in pr_body.split('\n'): + if line.strip().startswith('- [x] '): + tarball = line.strip()[6:] # Remove '- [x] ' prefix + checked_tarballs.append(tarball) + return checked_tarballs + + def extract_tarballs_from_pr_body(self, pr_body): + """Extract list of all tarballs from PR body.""" + tarballs = [] + for line in pr_body.split('\n'): + if line.strip().startswith('- ['): + tarball = line.strip()[6:] # Remove '- [ ] ' or '- [x] ' prefix + tarballs.append(tarball) + return tarballs + def reject(self): """Reject a tarball for ingestion.""" # Let's move the the tarball to the directory for 
rejected tarballs. - logging.info(f'Marking tarball {self.object} as rejected...') + log_message(LoggingScope.STATE_OPS, 'INFO', 'Marking tarball %s as rejected...', self.object) next_state = 'rejected' self.move_metadata_file(self.state, next_state) @@ -434,3 +655,82 @@ def issue_exists(self, title, state='open'): return True else: return False + + def get_link2pr_info(self): + """Get the link2pr information from the metadata file.""" + with open(self.local_metadata_path, 'r') as meta: + metadata = json.load(meta) + return metadata['link2pr']['repo'], metadata['link2pr']['pr'] + + +class EessiTarballGroup: + """Class to handle a group of tarballs that share the same link2pr information.""" + + def __init__(self, first_tarball, config, git_staging_repo, s3_bucket, cvmfs_repo): + """Initialize with the first tarball in the group.""" + self.first_tar = EessiTarball(first_tarball, config, git_staging_repo, s3_bucket, cvmfs_repo) + self.config = config + self.git_repo = git_staging_repo + self.s3_bucket = s3_bucket + self.cvmfs_repo = cvmfs_repo + + def download_tarballs_and_more(self, tarballs): + """Download all files associated with this group of tarballs.""" + for tarball in tarballs: + temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo) + log_message(LoggingScope.GROUP_OPS, 'INFO', "downloading files for '%s'", temp_tar.object) + temp_tar.download(force=True) + if not temp_tar.local_path or not temp_tar.local_metadata_path: + log_message(LoggingScope.GROUP_OPS, 'WARNING', "Skipping this tarball: %s", temp_tar.object) + return False + return True + + def process_group(self, tarballs): + """Process a group of tarballs together.""" + log_message(LoggingScope.GROUP_OPS, 'INFO', "Processing group of %d tarballs", len(tarballs)) + + if not self.download_tarballs_and_more(tarballs): + log_msg = "Downloading tarballs, metadata files and/or their signatures failed" + log_message(LoggingScope.ERROR, 'ERROR', log_msg) + return + + # Verify all tarballs have the same link2pr info + if not self.verify_group_consistency(tarballs): + log_message(LoggingScope.ERROR, 'ERROR', "Tarballs have inconsistent link2pr information") + return + + # Mark all tarballs as staged in the group branch, however need to handle first tarball differently + log_msg = "Processing first tarball in group: %s" + log_message(LoggingScope.GROUP_OPS, 'INFO', log_msg, self.first_tar.object) + self.first_tar.mark_new_tarball_as_staged('main') # this sets the state of the first tarball to 'staged' + for tarball in tarballs[1:]: + log_msg = "Processing tarball in group: %s" + log_message(LoggingScope.GROUP_OPS, 'INFO', log_msg, tarball) + temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo) + temp_tar.mark_new_tarball_as_staged('main') + + # Process the group for approval, only works correctly if first tarball is already in state 'staged' + self.first_tar.make_approval_request(tarballs) + + def to_string(self, oneline=False): + """Serialize tarball group info so it can be printed.""" + str = f"first tarball: {self.first_tar.to_string(oneline)}" + sep = "\n" if not oneline else "," + str += f"{sep} config: {self.config}" + str += f"{sep} GHrepo: {self.git_repo}" + str += f"{sep} s3....: {self.s3_bucket}" + str += f"{sep} bucket: {self.s3_bucket.bucket}" + str += f"{sep} cvmfs.: {self.cvmfs_repo}" + return str + + def verify_group_consistency(self, tarballs): + """Verify all tarballs in the group have the same link2pr information.""" + first_repo, first_pr 
= self.first_tar.get_link2pr_info() + + for tarball in tarballs[1:]: # Skip first tarball as we already have its info + temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo) + log_message(LoggingScope.DEBUG, 'DEBUG', "temp tar: %s", temp_tar.to_string()) + repo, pr = temp_tar.get_link2pr_info() + if repo != first_repo or pr != first_pr: + return False + return True diff --git a/scripts/automated_ingestion/remote_storage.py b/scripts/automated_ingestion/remote_storage.py new file mode 100644 index 00000000..2a386a7d --- /dev/null +++ b/scripts/automated_ingestion/remote_storage.py @@ -0,0 +1,34 @@ +from enum import Enum +from typing import Protocol, runtime_checkable + + +class DownloadMode(Enum): + """Enum defining different modes for downloading files.""" + FORCE = 'force' # Always download and overwrite + CHECK_REMOTE = 'check-remote' # Download if remote files have changed + CHECK_LOCAL = 'check-local' # Download if files don't exist locally (default) + + +@runtime_checkable +class RemoteStorageClient(Protocol): + """Protocol defining the interface for remote storage clients.""" + + def get_metadata(self, remote_path: str) -> dict: + """Get metadata about a remote object. + + Args: + remote_path: Path to the object in remote storage + + Returns: + Dictionary containing object metadata, including 'ETag' key + """ + ... + + def download(self, remote_path: str, local_path: str) -> None: + """Download a remote file to a local location. + + Args: + remote_path: Path to the object in remote storage + local_path: Local path where to save the file + """ + ... diff --git a/scripts/automated_ingestion/s3_bucket.py b/scripts/automated_ingestion/s3_bucket.py new file mode 100644 index 00000000..79fed289 --- /dev/null +++ b/scripts/automated_ingestion/s3_bucket.py @@ -0,0 +1,187 @@ +import os +from pathlib import Path +from typing import Dict, Optional + +import boto3 +from botocore.exceptions import ClientError +from utils import log_function_entry_exit, log_message, LoggingScope +from remote_storage import RemoteStorageClient + + +class EESSIS3Bucket(RemoteStorageClient): + """EESSI-specific S3 bucket implementation of the RemoteStorageClient protocol.""" + + @log_function_entry_exit() + def __init__(self, config, bucket_name: str): + """ + Initialize the EESSI S3 bucket. 
+ + Args: + config: Configuration object containing: + - aws.access_key_id: AWS access key ID (optional, can use AWS_ACCESS_KEY_ID env var) + - aws.secret_access_key: AWS secret access key (optional, can use AWS_SECRET_ACCESS_KEY env var) + - aws.endpoint_url: Custom endpoint URL for S3-compatible backends (optional) + - aws.verify: SSL verification setting (optional) + - True: Verify SSL certificates (default) + - False: Skip SSL certificate verification + - str: Path to CA bundle file + bucket_name: Name of the S3 bucket to use + """ + self.bucket = bucket_name + + # Get AWS credentials from environment or config + aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') or config.get('secrets', 'aws_access_key_id') + aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') or config.get('secrets', 'aws_secret_access_key') + + # Configure boto3 client + client_config = {} + + # Add endpoint URL if specified in config + if config.has_option('aws', 'endpoint_url'): + client_config['endpoint_url'] = config['aws']['endpoint_url'] + log_message(LoggingScope.DEBUG, 'DEBUG', "Using custom endpoint URL: %s", client_config['endpoint_url']) + + # Add SSL verification if specified in config + if config.has_option('aws', 'verify'): + verify = config['aws']['verify'] + if verify.lower() == 'false': + client_config['verify'] = False + log_message(LoggingScope.DEBUG, 'WARNING', "SSL verification disabled") + elif verify.lower() == 'true': + client_config['verify'] = True + else: + client_config['verify'] = verify # Assume it's a path to CA bundle + log_message(LoggingScope.DEBUG, 'DEBUG', "Using custom CA bundle: %s", verify) + + self.client = boto3.client( + 's3', + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + **client_config + ) + log_message(LoggingScope.DEBUG, 'INFO', "Initialized S3 client for bucket: %s", self.bucket) + + def list_objects_v2(self, **kwargs): + """ + List objects in the bucket using the underlying boto3 client. + + Args: + **kwargs: Additional arguments to pass to boto3.client.list_objects_v2 + + Returns: + Response from boto3.client.list_objects_v2 + """ + return self.client.list_objects_v2(Bucket=self.bucket, **kwargs) + + def download_file(self, key: str, filename: str) -> None: + """ + Download a file from S3 to a local file. + + Args: + key: The S3 key of the file to download + filename: The local path where the file should be saved + """ + self.client.download_file(self.bucket, key, filename) + + @log_function_entry_exit() + def get_metadata(self, remote_path: str) -> Dict: + """ + Get metadata about an S3 object. 
+ + Args: + remote_path: Path to the object in S3 + + Returns: + Dictionary containing object metadata, including 'ETag' key + """ + try: + log_message(LoggingScope.DEBUG, 'DEBUG', "Getting metadata for S3 object: %s", remote_path) + response = self.client.head_object(Bucket=self.bucket, Key=remote_path) + log_message(LoggingScope.DEBUG, 'DEBUG', "Retrieved metadata for %s: %s", remote_path, response) + return response + except ClientError as err: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to get metadata for %s: %s", remote_path, str(err)) + raise + + def _get_etag_file_path(self, local_path: str) -> Path: + """Get the path to the .etag file for a given local file.""" + return Path(local_path).with_suffix('.etag') + + def _read_etag(self, local_path: str) -> Optional[str]: + """Read the ETag from the .etag file if it exists.""" + etag_path = self._get_etag_file_path(local_path) + if etag_path.exists(): + try: + with open(etag_path, 'r') as f: + return f.read().strip() + except Exception as e: + log_message(LoggingScope.DEBUG, 'WARNING', "Failed to read ETag file %s: %s", etag_path, str(e)) + return None + return None + + def _write_etag(self, local_path: str, etag: str) -> None: + """Write the ETag to the .etag file.""" + etag_path = self._get_etag_file_path(local_path) + try: + with open(etag_path, 'w') as f: + f.write(etag) + log_message(LoggingScope.DEBUG, 'DEBUG', "Wrote ETag to %s", etag_path) + except Exception as e: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to write ETag file %s: %s", etag_path, str(e)) + # If we can't write the etag file, it's not critical + # The file will just be downloaded again next time + + @log_function_entry_exit() + def download(self, remote_path: str, local_path: str) -> None: + """ + Download an S3 object to a local location and store its ETag. + + Args: + remote_path: Path to the object in S3 + local_path: Local path where to save the file + """ + try: + log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading %s to %s", remote_path, local_path) + self.client.download_file(Bucket=self.bucket, Key=remote_path, Filename=local_path) + log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s to %s", remote_path, local_path) + except ClientError as err: + log_message(LoggingScope.ERROR, 'ERROR', "Failed to download %s: %s", remote_path, str(err)) + raise + + # Get metadata first to obtain the ETag + metadata = self.get_metadata(remote_path) + etag = metadata['ETag'] + + # Store the ETag + self._write_etag(local_path, etag) + + @log_function_entry_exit() + def get_bucket_url(self) -> str: + """ + Get the HTTPS URL for a bucket from an initialized boto3 client. + Works with both AWS S3 and MinIO/S3-compatible services. + """ + try: + # Check if this is a custom endpoint (MinIO) or AWS S3 + endpoint_url = self.client.meta.endpoint_url + + if endpoint_url: + # Custom endpoint (MinIO, DigitalOcean Spaces, etc.) 
+ # Most S3-compatible services use path-style URLs + bucket_url = f"{endpoint_url}/{self.bucket}" + + else: + # AWS S3 (no custom endpoint specified) + region = self.client.meta.region_name or 'us-east-1' + + # AWS S3 virtual-hosted-style URLs + if region == 'us-east-1': + bucket_url = f"https://{self.bucket}.s3.amazonaws.com" + else: + bucket_url = f"https://{self.bucket}.s3.{region}.amazonaws.com" + + return bucket_url + + except Exception as err: + log_message(LoggingScope.ERROR, 'ERROR', "Error getting bucket URL: %s", str(err)) + return None diff --git a/scripts/automated_ingestion/utils.py b/scripts/automated_ingestion/utils.py index 66843dd9..4b867764 100644 --- a/scripts/automated_ingestion/utils.py +++ b/scripts/automated_ingestion/utils.py @@ -1,6 +1,104 @@ import hashlib import json import requests +import logging +import functools +import time +import os +import inspect +from enum import IntFlag, auto +import sys + + +class LoggingScope(IntFlag): + """Enumeration of different logging scopes.""" + NONE = 0 + FUNC_ENTRY_EXIT = auto() # Function entry/exit logging + DOWNLOAD = auto() # Logging related to file downloads + VERIFICATION = auto() # Logging related to signature and checksum verification + STATE_OPS = auto() # Logging related to tarball state operations + GITHUB_OPS = auto() # Logging related to GitHub operations (PRs, issues, etc.) + GROUP_OPS = auto() # Logging related to tarball group operations + TASK_OPS = auto() # Logging related to task operations + ERROR = auto() # Error logging (separate from other scopes for easier filtering) + DEBUG = auto() # Debug-level logging (separate from other scopes for easier filtering) + ALL = (FUNC_ENTRY_EXIT | DOWNLOAD | VERIFICATION | STATE_OPS | + GITHUB_OPS | GROUP_OPS | TASK_OPS | ERROR | DEBUG) + + +# Global setting for logging scopes +ENABLED_LOGGING_SCOPES = LoggingScope.NONE + + +# Global variable to track call stack depth +_call_stack_depth = 0 + + +def set_logging_scopes(scopes): + """ + Set the enabled logging scopes. 
+ + Args: + scopes: Can be: + - A LoggingScope value + - A string with comma-separated values using +/- syntax: + - "+SCOPE" to enable a scope + - "-SCOPE" to disable a scope + - "ALL" or "+ALL" to enable all scopes + - "-ALL" to disable all scopes + Examples: + "+FUNC_ENTRY_EXIT" # Enable only function entry/exit + "+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE" # Enable function entry/exit but disable example + "+ALL,-FUNC_ENTRY_EXIT" # Enable all scopes except function entry/exit + """ + global ENABLED_LOGGING_SCOPES + + if isinstance(scopes, LoggingScope): + ENABLED_LOGGING_SCOPES = scopes + return + + if isinstance(scopes, str): + # Start with no scopes enabled + ENABLED_LOGGING_SCOPES = LoggingScope.NONE + + # Split into individual scope specifications + scope_specs = [s.strip() for s in scopes.split(",")] + + for spec in scope_specs: + if not spec: + continue + + # Check for ALL special case + if spec.upper() in ["ALL", "+ALL"]: + ENABLED_LOGGING_SCOPES = LoggingScope.ALL + continue + elif spec.upper() == "-ALL": + ENABLED_LOGGING_SCOPES = LoggingScope.NONE + continue + + # Parse scope name and operation + operation = spec[0] + scope_name = spec[1:].strip().upper() + + try: + scope_enum = LoggingScope[scope_name] + if operation == '+': + ENABLED_LOGGING_SCOPES |= scope_enum + elif operation == '-': + ENABLED_LOGGING_SCOPES &= ~scope_enum + else: + logging.warning(f"Invalid operation '{operation}' in scope specification: {spec}") + except KeyError: + logging.warning(f"Unknown logging scope: {scope_name}") + + elif isinstance(scopes, list): + # Convert list to comma-separated string and process + set_logging_scopes(",".join(scopes)) + + +def is_logging_scope_enabled(scope): + """Check if a specific logging scope is enabled.""" + return bool(ENABLED_LOGGING_SCOPES & scope) def send_slack_message(webhook, msg): @@ -25,3 +123,153 @@ def sha256sum(path): for byte_block in iter(lambda: f.read(8192), b''): sha256_hash.update(byte_block) return sha256_hash.hexdigest() + + +def log_function_entry_exit(logger=None): + """ + Decorator that logs function entry and exit with timing information. + Only logs if the FUNC_ENTRY_EXIT scope is enabled. + + Args: + logger: Optional logger instance. If not provided, uses the module's logger. 
+ """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + global _call_stack_depth + + if not is_logging_scope_enabled(LoggingScope.FUNC_ENTRY_EXIT): + return func(*args, **kwargs) + + if logger is None: + log = logging.getLogger(func.__module__) + else: + log = logger + + # Get context information if available + context = "" + if len(args) > 0 and hasattr(args[0], 'object'): + # For EessiTarball methods, show the tarball name and state + tarball = args[0] + filename = os.path.basename(tarball.object) + + # Format filename to show important parts + if len(filename) > 30: + parts = filename.split('-') + if len(parts) >= 6: # Ensure we have all required parts + # Get version, component, last part of architecture, and epoch + version = parts[1] + component = parts[2] + arch_last = parts[-2].split('-')[-1] # Last part of architecture + epoch = parts[-1] # includes file extension + filename = f"{version}-{component}-{arch_last}-{epoch}" + else: + # Fallback to simple truncation if format doesn't match + filename = f"{filename[:15]}...{filename[-12:]}" + + context = f" [{filename}" + if hasattr(tarball, 'state'): + context += f" in {tarball.state}" + context += "]" + + # Create indentation based on call stack depth + indent = " " * _call_stack_depth + + # Get file name and line number where the function is defined + file_name = os.path.basename(inspect.getsourcefile(func)) + source_lines, start_line = inspect.getsourcelines(func) + # Find the line with the actual function definition + def_line = next(i for i, line in enumerate(source_lines) if line.strip().startswith('def ')) + def_line_no = start_line + def_line + # Find the last non-empty line of the function + last_line = next(i for i, line in enumerate(reversed(source_lines)) if line.strip()) + last_line_no = start_line + len(source_lines) - 1 - last_line + + start_time = time.time() + log.info(f"{indent}[FUNC_ENTRY_EXIT] Entering {func.__name__} at {file_name}:{def_line_no}{context}") + _call_stack_depth += 1 + try: + result = func(*args, **kwargs) + _call_stack_depth -= 1 + end_time = time.time() + # For normal returns, show the last line of the function + log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{last_line_no}" + f"{context} (took {end_time - start_time:.2f}s)") + return result + except Exception as err: + _call_stack_depth -= 1 + end_time = time.time() + # For exceptions, try to get the line number from the exception + try: + exc_line_no = err.__traceback__.tb_lineno + except AttributeError: + exc_line_no = last_line_no + log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{exc_line_no}" + f"{context} with exception (took {end_time - start_time:.2f}s)") + raise err + return wrapper + return decorator + + +def log_message(scope, level, msg, *args, logger=None, **kwargs): + """ + Log a message if either: + 1. The specified scope is enabled, OR + 2. The current log level is equal to or higher than the specified level + + Args: + scope: LoggingScope value indicating which scope this logging belongs to + level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + msg: Message to log + logger: Optional logger instance. If not provided, uses the root logger. 
+ *args, **kwargs: Additional arguments to pass to the logging function + """ + log = logger or logging.getLogger() + log_level = getattr(logging, level.upper()) + + # Check if either condition is met + if not (is_logging_scope_enabled(scope) or log_level >= log.getEffectiveLevel()): + return + + # Create indentation based on call stack depth + indent = " " * _call_stack_depth + # Add scope to the message + scoped_msg = f"[{scope.name}] {msg}" + indented_msg = f"{indent}{scoped_msg}" + + # If scope is enabled, use the temporary handler + if is_logging_scope_enabled(scope): + # Save original handlers + original_handlers = list(log.handlers) + + # Create a temporary handler that accepts all levels + temp_handler = logging.StreamHandler(sys.stdout) + temp_handler.setLevel(logging.DEBUG) + temp_handler.setFormatter(logging.Formatter('%(levelname)-8s: %(message)s')) + + try: + # Remove existing handlers temporarily + for handler in original_handlers: + log.removeHandler(handler) + + # Add temporary handler + log.addHandler(temp_handler) + + # Log the message + log_func = getattr(log, level.lower()) + log_func(indented_msg, *args, **kwargs) + finally: + log.removeHandler(temp_handler) + # Restore original handlers + for handler in original_handlers: + if handler not in log.handlers: + log.addHandler(handler) + # Only use normal logging if scope is not enabled AND level is high enough + elif not is_logging_scope_enabled(scope) and log_level >= log.getEffectiveLevel(): + # Use normal logging with level check + log_func = getattr(log, level.lower()) + log_func(indented_msg, *args, **kwargs) + +# Example usage: +# log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading file: %s", filename) +# log_message(LoggingScope.ERROR, 'ERROR', "Failed to download: %s", error_msg) diff --git a/scripts/check-stratum-servers.py b/scripts/check-stratum-servers.py index de4270d9..4e35b09e 100755 --- a/scripts/check-stratum-servers.py +++ b/scripts/check-stratum-servers.py @@ -9,7 +9,8 @@ import yaml # Default location for EESSI's Ansible group vars file containing the CVMFS settings. -DEFAULT_ANSIBLE_GROUP_VARS_LOCATION = 'https://raw.githubusercontent.com/EESSI/filesystem-layer/main/inventory/group_vars/all.yml' +DEFAULT_ANSIBLE_GROUP_VARS_LOCATION = \ + 'https://raw.githubusercontent.com/EESSI/filesystem-layer/main/inventory/group_vars/all.yml' # Default fully qualified CVMFS repository name DEFAULT_CVMFS_FQRN = 'software.eessi.io' # Maximum amount of time (in minutes) that a Stratum 1 is allowed to not having performed a snapshot. 
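Stepping back briefly to the scope-based logging helpers added in scripts/automated_ingestion/utils.py above: a small usage sketch may make their interplay easier to follow. The names `set_logging_scopes`, `log_message`, `log_function_entry_exit` and `LoggingScope` come from the diff; the scope string, the `stage_tarball` function and the filename are invented here purely for illustration.

```
from utils import LoggingScope, log_function_entry_exit, log_message, set_logging_scopes

# Enable every scope except the (noisy) per-function entry/exit tracing.
set_logging_scopes("+ALL,-FUNC_ENTRY_EXIT")


@log_function_entry_exit()  # stays quiet, because FUNC_ENTRY_EXIT is disabled
def stage_tarball(name):
    # Emitted when the DOWNLOAD scope is enabled and the root logger allows INFO.
    log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading tarball %s", name)


stage_tarball("example.tar.gz")
```

With a scope enabled, `log_message` routes the message through a temporary stdout handler set to DEBUG, so it is not filtered by the regular handlers; with the scope disabled, the message only appears if its level clears the logger's effective level.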
@@ -32,8 +33,8 @@ def find_stratum_urls(vars_file, fqrn): """Find all Stratum 0/1 URLs in a given Ansible YAML vars file that contains the EESSI CVMFS configuration.""" try: group_vars = urllib.request.urlopen(vars_file) - except: - error(f'Cannot read the file that contains the Stratum 1 URLs from {vars_file}!') + except Exception as err: + error(f'Cannot read the file that contains the Stratum 1 URLs from {vars_file}!\nException: {err}') try: group_vars_yaml = yaml.safe_load(group_vars) s1_urls = group_vars_yaml['eessi_cvmfs_server_urls'][0]['urls'] @@ -44,8 +45,8 @@ def find_stratum_urls(vars_file, fqrn): break else: error(f'Could not find Stratum 0 URL in {vars_file}!') - except: - error(f'Cannot parse the yaml file from {vars_file}!') + except Exception as err: + error(f'Cannot parse the yaml file from {vars_file}!\nException: {err}') return s0_url, s1_urls @@ -64,7 +65,7 @@ def check_revisions(stratum_urls, fqrn): revisions[stratum] = int(rev_matches[0]) else: errors.append(f'Could not find revision number for stratum {stratum}!') - except urllib.error.HTTPError as e: + except urllib.error.HTTPError: errors.append(f'Could not connect to {stratum}!') # Check if all revisions are the same. @@ -95,10 +96,11 @@ def check_snapshots(s1_urls, fqrn, max_snapshot_delay=DEFAULT_MAX_SNAPSHOT_DELAY # Stratum 1 servers are supposed to make a snapshot every few minutes, # so let's check if it is not too far behind. if now - last_snapshot_time > datetime.timedelta(minutes=max_snapshot_delay): + time_diff = (now - last_snapshot_time).seconds / 60 errors.append( - f'Stratum 1 {s1} has made its last snapshot {(now - last_snapshot_time).seconds / 60:.0f} minutes ago!') - except urllib.error.HTTPError as e: - errors.append(f'Could not connect to {s1_json}!') + f'Stratum 1 {s1} has made its last snapshot {time_diff:.0f} minutes ago!') + except urllib.error.HTTPError: + errors.append(f'Could not connect to {s1_snapshot_file}!') if last_snapshots: # Get the Stratum 1 with the most recent snapshot...
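As a footnote to the last hunk: the snapshot-age message is derived from a `datetime.timedelta`. Below is a minimal, self-contained sketch of that check; the timestamps and the 30-minute threshold are invented here, and `total_seconds()` is used instead of the script's `.seconds` attribute, since `.seconds` only covers the part of the delta below one day.

```
import datetime

max_snapshot_delay = 30  # minutes (illustrative value, not taken from the script)
last_snapshot_time = datetime.datetime(2024, 1, 1, 11, 0)
now = datetime.datetime(2024, 1, 1, 12, 0)

age = now - last_snapshot_time
if age > datetime.timedelta(minutes=max_snapshot_delay):
    # total_seconds() also counts full days, unlike .seconds, which wraps every 24 hours
    print(f'Stratum 1 has made its last snapshot {age.total_seconds() / 60:.0f} minutes ago!')
```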