diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..b6b309e3
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,14 @@
+# This file is part of the EESSI filesystem layer,
+# see https://github.com/EESSI/filesystem-layer
+#
+# author: Thomas Roeblitz (@trz42)
+#
+# license: GPLv2
+#
+
+[flake8]
+max-line-length = 120
+
+# ignore "Black would make changes" produced by flake8-black
+# see also https://github.com/houndci/hound/issues/1769
+extend-ignore = BLK100
diff --git a/.github/workflows/check-flake8.yml b/.github/workflows/check-flake8.yml
new file mode 100644
index 00000000..f0ebe250
--- /dev/null
+++ b/.github/workflows/check-flake8.yml
@@ -0,0 +1,36 @@
+# This file is part of the EESSI filesystem layer,
+# see https://github.com/EESSI/filesystem-layer
+#
+# author: Thomas Roeblitz (@trz42)
+#
+# license: GPLv2
+#
+
+name: Run flake8 checks
+on: [push, pull_request]
+# Declare default permissions as read only.
+permissions: read-all
+jobs:
+ test:
+ runs-on: ubuntu-22.04
+ strategy:
+ matrix:
+ python: [3.7, 3.8, 3.9, '3.10', '3.11', '3.12']
+ fail-fast: false
+ steps:
+ - name: checkout
+ uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0
+
+ - name: set up Python
+ uses: actions/setup-python@13ae5bb136fac2878aff31522b9efb785519f984 # v4.3.0
+ with:
+ python-version: ${{matrix.python}}
+
+      - name: Install flake8
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install --upgrade flake8
+
+ - name: Run flake8 to verify PEP8-compliance of Python code
+ run: |
+ flake8
diff --git a/scripts/automated_ingestion/automated_ingestion.cfg.example b/scripts/automated_ingestion/automated_ingestion.cfg.example
index 68df3e4e..18009d88 100644
--- a/scripts/automated_ingestion/automated_ingestion.cfg.example
+++ b/scripts/automated_ingestion/automated_ingestion.cfg.example
@@ -63,13 +63,61 @@ pr_body = A new tarball has been staged for {pr_url}.
```
+
+
+ Overview of tarball contents
+
+ {tar_overview}
+
+
+# Method for creating staging PRs:
+# - 'individual': create one PR per tarball (old method)
+# - 'grouped': group tarballs by link2pr and create one PR per group (new method)
+staging_pr_method = individual
+
+# Template for individual tarball PRs
+individual_pr_body = A new tarball has been staged for {pr_url}.
+ Please review the contents of this tarball carefully.
+ Merging this PR will lead to automatic ingestion of the tarball to the repository {cvmfs_repo}.
+
+
+ Metadata of tarball
+
+ ```
+ {metadata}
+ ```
+
+
+
+
+ Overview of tarball contents
+
+ {tar_overview}
+
+
+
+# Template for grouped tarball PRs
+grouped_pr_body = A group of tarballs has been staged for {pr_url}.
+ Please review the contents of these tarballs carefully.
+ Merging this PR will lead to automatic ingestion of the approved tarballs to the repository {cvmfs_repo}.
+ Unchecked tarballs will be marked as rejected.
+
+ {tarballs}
+
Overview of tarball contents
{tar_overview}
+
+ {metadata}
+
+# Template for payload overview
+task_summary_payload_template =
+ {payload_overview}
+
[slack]
ingestion_notification = yes
diff --git a/scripts/automated_ingestion/automated_ingestion.py b/scripts/automated_ingestion/automated_ingestion.py
index 92dac552..974f8497 100755
--- a/scripts/automated_ingestion/automated_ingestion.py
+++ b/scripts/automated_ingestion/automated_ingestion.py
@@ -1,12 +1,15 @@
#!/usr/bin/env python3
-from eessitarball import EessiTarball
-from pid.decorator import pidfile
+from eessitarball import EessiTarball, EessiTarballGroup
+from eessi_data_object import EESSIDataAndSignatureObject
+from eessi_task import EESSITask, TaskState
+from eessi_task_description import EESSITaskDescription
+from s3_bucket import EESSIS3Bucket
+from pid.decorator import pidfile # noqa: F401
from pid import PidFileError
+from utils import log_function_entry_exit, log_message, LoggingScope, set_logging_scopes
import argparse
-import boto3
-import botocore
import configparser
import github
import json
@@ -14,6 +17,8 @@
import os
import pid
import sys
+from pathlib import Path
+from typing import List
REQUIRED_CONFIG = {
'secrets': ['aws_secret_access_key', 'aws_access_key_id', 'github_pat'],
@@ -33,83 +38,350 @@
def error(msg, code=1):
"""Print an error and exit."""
- logging.error(msg)
+ log_message(LoggingScope.ERROR, 'ERROR', msg)
sys.exit(code)
-def find_tarballs(s3, bucket, extension='.tar.gz', metadata_extension='.meta.txt'):
- """Return a list of all tarballs in an S3 bucket that have a metadata file with the given extension (and same filename)."""
+def find_tarballs(s3_bucket, extension='.tar.gz', metadata_extension='.meta.txt'):
+ """
+ Return a list of all tarballs in an S3 bucket that have a metadata file with
+ the given extension (and same filename).
+ """
# TODO: list_objects_v2 only returns up to 1000 objects
- s3_objects = s3.list_objects_v2(Bucket=bucket).get('Contents', [])
+ s3_objects = s3_bucket.list_objects_v2().get('Contents', [])
files = [obj['Key'] for obj in s3_objects]
tarballs = [
file
for file in files
- if file.endswith(extension)
- and file + metadata_extension in files
+ if file.endswith(extension) and file + metadata_extension in files
]
return tarballs
+@log_function_entry_exit()
+def find_tarball_groups(s3_bucket, config, extension='.tar.gz', metadata_extension='.meta.txt'):
+ """Return a dictionary of tarball groups, keyed by (repo, pr_number)."""
+ tarballs = find_tarballs(s3_bucket, extension, metadata_extension)
+ groups = {}
+
+ for tarball in tarballs:
+ # Download metadata to get link2pr info
+ metadata_file = tarball + metadata_extension
+ local_metadata = os.path.join(config['paths']['download_dir'], os.path.basename(metadata_file))
+
+ try:
+ s3_bucket.download_file(metadata_file, local_metadata)
+ with open(local_metadata, 'r') as meta:
+ metadata = json.load(meta)
+ repo = metadata['link2pr']['repo']
+ pr = metadata['link2pr']['pr']
+ group_key = (repo, pr)
+
+ if group_key not in groups:
+ groups[group_key] = []
+ groups[group_key].append(tarball)
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to process metadata for %s: %s", tarball, err)
+ continue
+ finally:
+ # Clean up downloaded metadata file
+ if os.path.exists(local_metadata):
+ os.remove(local_metadata)
+
+ return groups
+
+
+@log_function_entry_exit()
def parse_config(path):
"""Parse the configuration file."""
config = configparser.ConfigParser()
try:
config.read(path)
- except:
- error(f'Unable to read configuration file {path}!')
+ except Exception as err:
+ error(f'Unable to read configuration file {path}!\nException: {err}')
# Check if all required configuration parameters/sections can be found.
for section in REQUIRED_CONFIG.keys():
- if not section in config:
+ if section not in config:
error(f'Missing section "{section}" in configuration file {path}.')
for item in REQUIRED_CONFIG[section]:
- if not item in config[section]:
+ if item not in config[section]:
error(f'Missing configuration item "{item}" in section "{section}" of configuration file {path}.')
+
+ # Validate staging_pr_method
+ staging_method = config['github'].get('staging_pr_method', 'individual')
+ if staging_method not in ['individual', 'grouped']:
+ error(
+ f'Invalid staging_pr_method: "{staging_method}" in configuration file {path}. '
+ 'Must be either "individual" or "grouped".'
+ )
+
+ # Validate PR body templates
+ if staging_method == 'individual' and 'individual_pr_body' not in config['github']:
+ error(f'Missing "individual_pr_body" in configuration file {path}.')
+ if staging_method == 'grouped' and 'grouped_pr_body' not in config['github']:
+ error(f'Missing "grouped_pr_body" in configuration file {path}.')
+
return config
+@log_function_entry_exit()
def parse_args():
"""Parse the command-line arguments."""
parser = argparse.ArgumentParser()
+
+ # Logging options
+ logging_group = parser.add_argument_group('Logging options')
+ logging_group.add_argument('--log-file',
+ help='Path to log file (overrides config file setting)')
+ logging_group.add_argument('--console-level',
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ help='Logging level for console output (overrides config file setting)')
+ logging_group.add_argument('--file-level',
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+ help='Logging level for file output (overrides config file setting)')
+ logging_group.add_argument('--quiet',
+ action='store_true',
+ help='Suppress console output (overrides all other console settings)')
+ logging_group.add_argument('--log-scopes',
+ help='Comma-separated list of logging scopes using +/- syntax. '
+ 'Examples: "+FUNC_ENTRY_EXIT" (enable only function entry/exit), '
+ '"+ALL,-FUNC_ENTRY_EXIT" (enable all except function entry/exit), '
+ '"+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE" (enable function entry/exit but disable example)')
+
+ # Existing arguments
parser.add_argument('-c', '--config', type=str, help='path to configuration file',
default='automated_ingestion.cfg', dest='config')
parser.add_argument('-d', '--debug', help='enable debug mode', action='store_true', dest='debug')
- parser.add_argument('-l', '--list', help='only list available tarballs', action='store_true', dest='list_only')
- args = parser.parse_args()
- return args
+ parser.add_argument('-l', '--list', help='only list available tarballs or tasks', action='store_true',
+ dest='list_only')
+ parser.add_argument('--task-based', help='use task-based ingestion instead of tarball-based. '
+ 'Optionally specify comma-separated list of extensions (default: .task)',
+ nargs='?', const='.task', default=False)
+
+ return parser.parse_args()
+
+
+@log_function_entry_exit()
+def setup_logging(config, args):
+ """
+ Configure logging based on configuration file and command line arguments.
+ Command line arguments take precedence over config file settings.
+
+ Args:
+ config: Configuration dictionary
+ args: Parsed command line arguments
+ """
+ # Get settings from config file
+ log_file = config['logging'].get('filename')
+ config_console_level = LOG_LEVELS.get(config['logging'].get('level', 'INFO').upper(), logging.INFO)
+ config_file_level = LOG_LEVELS.get(config['logging'].get('file_level', 'DEBUG').upper(), logging.DEBUG)
+
+ # Override with command line arguments if provided
+ log_file = args.log_file if args.log_file else log_file
+ console_level = getattr(logging, args.console_level) if args.console_level else config_console_level
+ file_level = getattr(logging, args.file_level) if args.file_level else config_file_level
+
+ # Debug mode overrides console level
+ if args.debug:
+ console_level = logging.DEBUG
+
+ # Set up logging scopes
+ if args.log_scopes:
+ set_logging_scopes(args.log_scopes)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Enabled logging scopes: %s", args.log_scopes)
+
+ # Create logger
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG) # Set root logger to lowest level
+
+ # Create formatters
+ console_formatter = logging.Formatter('%(levelname)-8s: %(message)s')
+ file_formatter = logging.Formatter('%(asctime)s - %(levelname)-8s: %(message)s')
+
+ # Console handler (only if not quiet)
+ if not args.quiet:
+ console_handler = logging.StreamHandler(sys.stdout)
+ console_handler.setLevel(console_level)
+ console_handler.setFormatter(console_formatter)
+ logger.addHandler(console_handler)
+
+ # File handler (if log file is specified)
+ if log_file:
+ # Ensure log directory exists
+ log_path = Path(log_file)
+ log_path.parent.mkdir(parents=True, exist_ok=True)
+
+ file_handler = logging.FileHandler(log_file)
+ file_handler.setLevel(file_level)
+ file_handler.setFormatter(file_formatter)
+ logger.addHandler(file_handler)
+
+ return logger
@pid.decorator.pidfile('automated_ingestion.pid')
+@log_function_entry_exit()
def main():
"""Main function."""
args = parse_args()
config = parse_config(args.config)
- log_file = config['logging'].get('filename', None)
- log_format = config['logging'].get('format', '%(levelname)s:%(message)s')
- log_level = LOG_LEVELS.get(config['logging'].get('level', 'INFO').upper(), logging.WARN)
- log_level = logging.DEBUG if args.debug else log_level
- logging.basicConfig(filename=log_file, format=log_format, level=log_level)
+ setup_logging(config, args)
+
# TODO: check configuration: secrets, paths, permissions on dirs, etc
gh_pat = config['secrets']['github_pat']
gh_staging_repo = github.Github(gh_pat).get_repo(config['github']['staging_repo'])
- s3 = boto3.client(
- 's3',
- aws_access_key_id=config['secrets']['aws_access_key_id'],
- aws_secret_access_key=config['secrets']['aws_secret_access_key'],
- )
buckets = json.loads(config['aws']['staging_buckets'])
for bucket, cvmfs_repo in buckets.items():
- tarballs = find_tarballs(s3, bucket)
- if args.list_only:
- for num, tarball in enumerate(tarballs):
- print(f'[{bucket}] {num}: {tarball}')
+ # Create our custom S3 bucket for this bucket
+ s3_bucket = EESSIS3Bucket(config, bucket)
+
+ if args.task_based:
+ # Task-based listing
+ extensions = args.task_based.split(',')
+ tasks = find_deployment_tasks(s3_bucket, extensions)
+ if args.list_only:
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "#tasks: %d", len(tasks))
+ for num, task in enumerate(tasks):
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "[%s] %d: %s", bucket, num, task)
+ else:
+ # Process each task file
+ for task_path in tasks:
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "Processing task: %s", task_path)
+ try:
+ # Create EESSITask for the task file
+ try:
+ task = EESSITask(
+ EESSITaskDescription(EESSIDataAndSignatureObject(config, task_path, s3_bucket)),
+ config, cvmfs_repo, gh_staging_repo
+ )
+
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to create EESSITask for task %s: %s",
+ task_path, str(err))
+ continue
+
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "Task: %s", task)
+
+ previous_state = None
+ current_state = task.determine_state()
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "Task '%s' is in state '%s'",
+ task_path, current_state.name)
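+                        # Drive the task's state machine: repeatedly call handle() until the task is DONE
+                        # or the state stops changing (i.e. no handler made further progress)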
+ while (current_state is not None and
+ current_state != TaskState.DONE and
+ previous_state != current_state):
+ previous_state = current_state
+ log_message(LoggingScope.GROUP_OPS, 'INFO',
+ "Task '%s': BEFORE handle(): previous state = '%s', current state = '%s'",
+ task_path, previous_state.name, current_state.name)
+ current_state = task.handle()
+ log_message(LoggingScope.GROUP_OPS, 'INFO',
+ "Task '%s': AFTER handle(): previous state = '%s', current state = '%s'",
+ task_path, previous_state.name, current_state.name)
+
+ # # TODO: update the information shown below (what makes sense to show?)
+ # # Log information about the task
+ # task_object = task.description.task_object
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task file: %s", task_object.local_file_path)
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "Signature file: %s", task_object.local_sig_path)
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "Signature verified: %s",
+ # task.description.signature_verified)
+
+ # # Log the ETags of the downloaded task file
+ # file_etag, sig_etag = task.description.task_object.get_etags()
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task file %s has ETag: %s", task_path, file_etag)
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "Task signature %s has ETag: %s",
+ # task.description.task_object.remote_sig_path, sig_etag)
+
+ # # TODO: Process the task file contents
+ # # This would involve reading the task file, parsing its contents,
+ # # and performing the required actions based on the task type
+ # log_message(LoggingScope.GROUP_OPS, 'INFO', "TODO: Processing task file: %s", task_path)
+ # task.handle()
+
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to process task %s: %s", task_path, str(err))
+ continue
else:
- for tarball in tarballs:
- tar = EessiTarball(tarball, config, gh_staging_repo, s3, bucket, cvmfs_repo)
- tar.run_handler()
+ # Original tarball-based processing
+ if config['github'].get('staging_pr_method', 'individual') == 'grouped':
+ # use new grouped PR method
+ tarball_groups = find_tarball_groups(s3_bucket, config)
+ if args.list_only:
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "#tarball_groups: %d", len(tarball_groups))
+ for (repo, pr_id), tarballs in tarball_groups.items():
+ log_message(LoggingScope.GROUP_OPS, 'INFO', " %s#%s: #tarballs %d", repo, pr_id, len(tarballs))
+ else:
+ for (repo, pr_id), tarballs in tarball_groups.items():
+ if tarballs:
+ # Create a group for these tarballs
+ group = EessiTarballGroup(tarballs[0], config, gh_staging_repo, s3_bucket, cvmfs_repo)
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "group created\n%s",
+ group.to_string(oneline=True))
+ group.process_group(tarballs)
+ else:
+ # use old individual PR method
+ tarballs = find_tarballs(s3_bucket)
+ if args.list_only:
+ for num, tarball in enumerate(tarballs):
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "[%s] %d: %s", bucket, num, tarball)
+ else:
+ for tarball in tarballs:
+ tar = EessiTarball(tarball, config, gh_staging_repo, s3_bucket, cvmfs_repo)
+ tar.run_handler()
+
+
+@log_function_entry_exit()
+def find_deployment_tasks(s3_bucket, extensions: List[str] = None) -> List[str]:
+ """
+ Return a list of all task files in an S3 bucket with the given extensions,
+ but only if a corresponding payload file exists (same name without extension).
+
+ Args:
+ s3_bucket: EESSIS3Bucket instance
+ extensions: List of file extensions to look for (default: ['.task'])
+
+ Returns:
+ List of task filenames found in the bucket that have a corresponding payload
+ """
+ if extensions is None:
+ extensions = ['.task']
+
+ files = []
+ continuation_token = None
+
+ while True:
+ # List objects with pagination
+ if continuation_token:
+ response = s3_bucket.list_objects_v2(
+ ContinuationToken=continuation_token
+ )
+ else:
+ response = s3_bucket.list_objects_v2()
+
+ # Add files from this page
+ files.extend([obj['Key'] for obj in response.get('Contents', [])])
+
+ # Check if there are more pages
+ if response.get('IsTruncated'):
+ continuation_token = response.get('NextContinuationToken')
+ else:
+ break
+
+ # Create a set of all files for faster lookup
+ file_set = set(files)
+
+ # Return only task files that have a corresponding payload
+ result = []
+ for file in files:
+ for ext in extensions:
+ if file.endswith(ext) and file[:-len(ext)] in file_set:
+ result.append(file)
+ break # Found a matching extension, no need to check others
+
+ return result
if __name__ == '__main__':
diff --git a/scripts/automated_ingestion/eessi_data_object.py b/scripts/automated_ingestion/eessi_data_object.py
new file mode 100644
index 00000000..c7adc05b
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_data_object.py
@@ -0,0 +1,334 @@
+import subprocess
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple
+
+import configparser
+
+from utils import log_function_entry_exit, log_message, LoggingScope
+from remote_storage import RemoteStorageClient, DownloadMode
+
+
+@dataclass
+class EESSIDataAndSignatureObject:
+ """Class representing an EESSI data file and its signature in remote storage and locally."""
+
+ # Configuration
+ config: configparser.ConfigParser
+
+ # Remote paths
+ remote_file_path: str # Path to data file in remote storage
+ remote_sig_path: str # Path to signature file in remote storage
+
+ # Local paths
+ local_file_path: Path # Path to local data file
+ local_sig_path: Path # Path to local signature file
+
+ # Remote storage client
+ remote_client: RemoteStorageClient
+
+ @log_function_entry_exit()
+ def __init__(self, config: configparser.ConfigParser, remote_file_path: str, remote_client: RemoteStorageClient):
+ """
+ Initialize an EESSI data and signature object handler.
+
+ Args:
+ config: Configuration object containing remote storage and local directory information
+ remote_file_path: Path to data file in remote storage
+ remote_client: Remote storage client implementing the RemoteStorageClient protocol
+ """
+ self.config = config
+ self.remote_file_path = remote_file_path
+ sig_ext = config['signatures']['signature_file_extension']
+ self.remote_sig_path = remote_file_path + sig_ext
+
+ # Set up local paths
+ local_dir = Path(config['paths']['download_dir'])
+ # Use the full remote path structure, removing any leading slashes
+ remote_path = remote_file_path.lstrip('/')
+ self.local_file_path = local_dir.joinpath(remote_path)
+ self.local_sig_path = local_dir.joinpath(remote_path + sig_ext)
+ self.remote_client = remote_client
+
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Initialized EESSIDataAndSignatureObject for %s", remote_file_path)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Local file path: %s", self.local_file_path)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Local signature path: %s", self.local_sig_path)
+
+ def _get_etag_file_path(self, local_path: Path) -> Path:
+ """Get the path to the .etag file for a given local file."""
+ return local_path.with_suffix('.etag')
+
+ def _get_local_etag(self, local_path: Path) -> Optional[str]:
+ """Get the ETag of a local file from its .etag file."""
+ etag_path = self._get_etag_file_path(local_path)
+ if etag_path.exists():
+ try:
+ with open(etag_path, 'r') as f:
+ return f.read().strip()
+ except Exception as err:
+ log_message(LoggingScope.DEBUG, 'WARNING', "Failed to read ETag file %s: %s", etag_path, str(err))
+ return None
+ return None
+
+    def get_etags(self) -> Tuple[Optional[str], Optional[str]]:
+ """
+ Get the ETags of both the data file and its signature.
+
+ Returns:
+ Tuple containing (data_file_etag, signature_file_etag)
+ """
+ return (
+ self._get_local_etag(self.local_file_path),
+ self._get_local_etag(self.local_sig_path)
+ )
+
+ @log_function_entry_exit()
+ def verify_signature(self) -> bool:
+ """
+ Verify the signature of the data file using the corresponding signature file.
+
+ Returns:
+ bool: True if the signature is valid or if signatures are not required, False otherwise
+ """
+ # Check if signature file exists
+ if not self.local_sig_path.exists():
+ log_message(LoggingScope.VERIFICATION, 'WARNING', "Signature file %s is missing",
+ self.local_sig_path)
+
+ # If signatures are required, return failure
+ if self.config['signatures'].getboolean('signatures_required', True):
+ log_message(LoggingScope.ERROR, 'ERROR', "Signature file %s is missing and signatures are required",
+ self.local_sig_path)
+ return False
+ else:
+ log_message(LoggingScope.VERIFICATION, 'INFO',
+ "Signature file %s is missing, but signatures are not required",
+ self.local_sig_path)
+ return True
+
+ # If signatures are provided, we should always verify them, regardless of the signatures_required setting
+ verify_runenv = self.config['signatures']['signature_verification_runenv'].split()
+ verify_script = self.config['signatures']['signature_verification_script']
+ allowed_signers_file = self.config['signatures']['allowed_signers_file']
+
+ # Check if verification tools exist
+ if not Path(verify_script).exists():
+ log_message(LoggingScope.ERROR, 'ERROR',
+ "Unable to verify signature: verification script %s does not exist", verify_script)
+ return False
+
+ if not Path(allowed_signers_file).exists():
+ log_message(LoggingScope.ERROR, 'ERROR',
+ "Unable to verify signature: allowed signers file %s does not exist", allowed_signers_file)
+ return False
+
+ # Run the verification command with named parameters
+ cmd = verify_runenv + [
+ verify_script,
+ '--verify',
+ '--allowed-signers-file', allowed_signers_file,
+ '--file', str(self.local_file_path),
+ '--signature-file', str(self.local_sig_path)
+ ]
+ log_message(LoggingScope.VERIFICATION, 'INFO', "Running command: %s", ' '.join(cmd))
+
+ try:
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ if result.returncode == 0:
+ log_message(LoggingScope.VERIFICATION, 'INFO',
+ "Successfully verified signature for %s", self.local_file_path)
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', " stdout: %s", result.stdout)
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', " stderr: %s", result.stderr)
+ return True
+ else:
+ log_message(LoggingScope.ERROR, 'ERROR',
+ "Signature verification failed for %s", self.local_file_path)
+ log_message(LoggingScope.ERROR, 'ERROR', " stdout: %s", result.stdout)
+ log_message(LoggingScope.ERROR, 'ERROR', " stderr: %s", result.stderr)
+ return False
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR',
+ "Error during signature verification for %s: %s",
+ self.local_file_path, str(err))
+ return False
+
+ @log_function_entry_exit()
+ def download(self, mode: DownloadMode = DownloadMode.CHECK_REMOTE) -> bool:
+ """
+ Download data file and signature based on the specified mode.
+
+ Args:
+ mode: Download mode to use
+
+ Returns:
+ True if files were downloaded, False otherwise
+ """
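+        # Decide whether a download is needed, depending on the requested DownloadMode:
+        # FORCE always downloads, CHECK_REMOTE compares local and remote ETags,
+        # CHECK_LOCAL only checks whether local copies already exist.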
+ # If mode is FORCE, we always download regardless of local or remote state
+ if mode == DownloadMode.FORCE:
+ should_download = True
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Forcing download of %s", self.remote_file_path)
+ # For CHECK_REMOTE mode, check if we can optimize
+ elif mode == DownloadMode.CHECK_REMOTE:
+ # Optimization: Check if local files exist first
+ local_files_exist = (
+ self.local_file_path.exists() and
+ self.local_sig_path.exists()
+ )
+
+ # If files don't exist locally, we can skip ETag checks
+ if not local_files_exist:
+ log_message(LoggingScope.DOWNLOAD, 'INFO',
+ "Local files missing, skipping ETag checks and downloading %s",
+ self.remote_file_path)
+ should_download = True
+ else:
+ # First check if we have local ETags
+ try:
+ local_file_etag = self._get_local_etag(self.local_file_path)
+ local_sig_etag = self._get_local_etag(self.local_sig_path)
+
+ if local_file_etag:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Local file ETag: %s", local_file_etag)
+ else:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "No local file ETag found")
+ if local_sig_etag:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Local signature ETag: %s", local_sig_etag)
+ else:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "No local signature ETag found")
+
+ # If we don't have local ETags, we need to download
+ if not local_file_etag or not local_sig_etag:
+ should_download = True
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Missing local ETags, downloading %s",
+ self.remote_file_path)
+ else:
+ # Get remote ETags and compare
+ remote_file_etag = self.remote_client.get_metadata(self.remote_file_path)['ETag']
+ remote_sig_etag = self.remote_client.get_metadata(self.remote_sig_path)['ETag']
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Remote file ETag: %s", remote_file_etag)
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Remote signature ETag: %s", remote_sig_etag)
+
+ should_download = (
+ remote_file_etag != local_file_etag or
+ remote_sig_etag != local_sig_etag
+ )
+ if should_download:
+ if remote_file_etag != local_file_etag:
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "File ETag changed from %s to %s",
+ local_file_etag, remote_file_etag)
+ if remote_sig_etag != local_sig_etag:
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Signature ETag changed from %s to %s",
+ local_sig_etag, remote_sig_etag)
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Remote files have changed, downloading %s",
+ self.remote_file_path)
+ else:
+ log_message(LoggingScope.DOWNLOAD, 'INFO',
+ "Remote files unchanged, skipping download of %s",
+ self.remote_file_path)
+ except Exception as etag_err:
+ # If we get any error with ETags, we'll just download the files
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error handling ETags, will download files: %s",
+ str(etag_err))
+ should_download = True
+ else: # CHECK_LOCAL
+ should_download = (
+ not self.local_file_path.exists() or
+ not self.local_sig_path.exists()
+ )
+ if should_download:
+ if not self.local_file_path.exists():
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Local file missing: %s", self.local_file_path)
+ if not self.local_sig_path.exists():
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Local signature missing: %s", self.local_sig_path)
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Local files missing, downloading %s",
+ self.remote_file_path)
+ else:
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Local files exist, skipping download of %s",
+ self.remote_file_path)
+
+ if not should_download:
+ return False
+
+ # Ensure local directory exists
+ self.local_file_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Download files
+ try:
+ # Download the main file first
+ self.remote_client.download(self.remote_file_path, str(self.local_file_path))
+
+ # Get and log the ETag of the downloaded file
+ try:
+ file_etag = self._get_local_etag(self.local_file_path)
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Downloaded %s with ETag: %s",
+ self.remote_file_path, file_etag)
+ except Exception as etag_err:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error getting ETag for %s: %s",
+ self.remote_file_path, str(etag_err))
+
+ # Try to download the signature file
+ try:
+ self.remote_client.download(self.remote_sig_path, str(self.local_sig_path))
+ try:
+ sig_etag = self._get_local_etag(self.local_sig_path)
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Downloaded %s with ETag: %s",
+ self.remote_sig_path, sig_etag)
+ except Exception as etag_err:
+ log_message(LoggingScope.DOWNLOAD, 'DEBUG', "Error getting ETag for %s: %s",
+ self.remote_sig_path, str(etag_err))
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s and its signature",
+ self.remote_file_path)
+ except Exception as sig_err:
+ # Check if signatures are required
+ if self.config['signatures'].getboolean('signatures_required', True):
+ # If signatures are required, clean up everything since we can't proceed
+ if self.local_file_path.exists():
+ self.local_file_path.unlink()
+ # Clean up etag files regardless of whether their data files exist
+ file_etag_path = self._get_etag_file_path(self.local_file_path)
+ if file_etag_path.exists():
+ file_etag_path.unlink()
+ sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+ if sig_etag_path.exists():
+ sig_etag_path.unlink()
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to download required signature for %s: %s",
+ self.remote_file_path, str(sig_err))
+ raise
+ else:
+ # If signatures are optional, just clean up any partial signature files
+ if self.local_sig_path.exists():
+ self.local_sig_path.unlink()
+ sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+ if sig_etag_path.exists():
+ sig_etag_path.unlink()
+ log_message(LoggingScope.DOWNLOAD, 'WARNING', "Failed to download optional signature for %s: %s",
+ self.remote_file_path, str(sig_err))
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s (signature optional)",
+ self.remote_file_path)
+
+ return True
+ except Exception as err:
+ # This catch block is only for errors in the main file download
+ # Clean up partially downloaded files and their etags
+ if self.local_file_path.exists():
+ self.local_file_path.unlink()
+ if self.local_sig_path.exists():
+ self.local_sig_path.unlink()
+ # Clean up etag files regardless of whether their data files exist
+ file_etag_path = self._get_etag_file_path(self.local_file_path)
+ if file_etag_path.exists():
+ file_etag_path.unlink()
+ sig_etag_path = self._get_etag_file_path(self.local_sig_path)
+ if sig_etag_path.exists():
+ sig_etag_path.unlink()
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to download %s: %s", self.remote_file_path, str(err))
+ raise
+
+ @log_function_entry_exit()
+ def get_url(self) -> str:
+ """Get the URL of the data file."""
+ return f"https://{self.remote_client.bucket}.s3.amazonaws.com/{self.remote_file_path}"
+
+ def __str__(self) -> str:
+ """Return a string representation of the EESSI data and signature object."""
+ return f"EESSIDataAndSignatureObject({self.remote_file_path})"
diff --git a/scripts/automated_ingestion/eessi_task.py b/scripts/automated_ingestion/eessi_task.py
new file mode 100644
index 00000000..bd863946
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_task.py
@@ -0,0 +1,1404 @@
+from enum import Enum, auto
+from typing import Dict, List, Tuple, Optional
+from functools import total_ordering
+
+import base64
+import os
+import subprocess
+import traceback
+
+from eessi_data_object import EESSIDataAndSignatureObject
+from eessi_task_action import EESSITaskAction
+from eessi_task_description import EESSITaskDescription
+from eessi_task_payload import EESSITaskPayload
+from utils import send_slack_message, log_message, LoggingScope, log_function_entry_exit
+
+from github import Github, GithubException, InputGitTreeElement, UnknownObjectException
+from github.PullRequest import PullRequest
+from github.Branch import Branch
+
+
+class SequenceStatus(Enum):
+ DOES_NOT_EXIST = auto()
+ IN_PROGRESS = auto()
+ FINISHED = auto()
+
+
+@total_ordering
+class TaskState(Enum):
+ UNDETERMINED = auto() # The task state was not determined yet
+ NEW_TASK = auto() # The task has been created but not yet processed
+ PAYLOAD_STAGED = auto() # The task's payload has been staged to the Stratum-0
+ PULL_REQUEST = auto() # A PR for the task has been created or updated in some staging repository
+ APPROVED = auto() # The PR for the task has been approved
+ REJECTED = auto() # The PR for the task has been rejected
+ INGESTED = auto() # The task's payload has been applied to the target CernVM-FS repository
+ DONE = auto() # The task has been completed
+
+ @classmethod
+ def from_string(cls, name, default=None, case_sensitive=False):
+ log_message(LoggingScope.TASK_OPS, 'INFO', "from_string: %s", name)
+ if case_sensitive:
+ to_return = cls.__members__.get(name, default)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "from_string will return: %s", to_return)
+ return to_return
+
+ try:
+ to_return = cls[name.upper()]
+ log_message(LoggingScope.TASK_OPS, 'INFO', "from_string will return: %s", to_return)
+ return to_return
+ except KeyError:
+ return default
+
+ def __lt__(self, other):
+ if self.__class__ is other.__class__:
+ return self.value < other.value
+ return NotImplemented
+
+ def __str__(self):
+ return self.name.upper()
+
+
+class EESSITask:
+ description: EESSITaskDescription
+ payload: EESSITaskPayload
+ action: EESSITaskAction
+ git_repo: Github
+ config: Dict
+
+ @log_function_entry_exit()
+ def __init__(self, description: EESSITaskDescription, config: Dict, cvmfs_repo: str, git_repo: Github):
+ self.description = description
+ self.config = config
+ self.cvmfs_repo = cvmfs_repo
+ self.git_repo = git_repo
+ self.action = self._determine_task_action()
+
+ # Define valid state transitions for all actions
+        # NOTE: TaskState.APPROVED must be the first element in the PULL_REQUEST transition list,
+        #       or _next_state() will not work
+ self.valid_transitions = {
+ TaskState.UNDETERMINED: [TaskState.NEW_TASK, TaskState.PAYLOAD_STAGED, TaskState.PULL_REQUEST,
+ TaskState.APPROVED, TaskState.REJECTED, TaskState.INGESTED, TaskState.DONE],
+ TaskState.NEW_TASK: [TaskState.PAYLOAD_STAGED],
+ TaskState.PAYLOAD_STAGED: [TaskState.PULL_REQUEST],
+ TaskState.PULL_REQUEST: [TaskState.APPROVED, TaskState.REJECTED],
+ TaskState.APPROVED: [TaskState.INGESTED],
+ TaskState.REJECTED: [], # terminal state
+ TaskState.INGESTED: [], # terminal state
+ TaskState.DONE: [] # virtual terminal state, not used to write on GitHub
+ }
+
+ self.payload = None
+ state = self.determine_state()
+ if state >= TaskState.PAYLOAD_STAGED:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "initializing payload object in constructor for EESSITask")
+ self._init_payload_object()
+
+ @log_function_entry_exit()
+ def _determine_task_action(self) -> EESSITaskAction:
+ """
+ Determine the action type based on task description metadata.
+ """
+ if 'task' in self.description.metadata and 'action' in self.description.metadata['task']:
+ action_str = self.description.metadata['task']['action'].lower()
+ if action_str == "nop":
+ return EESSITaskAction.NOP
+ elif action_str == "delete":
+ return EESSITaskAction.DELETE
+ elif action_str == "add":
+ return EESSITaskAction.ADD
+ elif action_str == "update":
+ return EESSITaskAction.UPDATE
+ return EESSITaskAction.UNKNOWN
+
+ @log_function_entry_exit()
+ def _state_file_with_prefix_exists_in_repo_branch(self, file_path_prefix: str, branch_name: str = None) -> bool:
+ """
+ Check if a file exists in a repository branch.
+
+ Args:
+ file_path_prefix: the prefix of the file path
+ branch_name: the branch to check
+
+ Returns:
+ True if a file with the prefix exists in the branch, False otherwise
+ """
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ # branch = self._get_branch_from_name(branch_name)
+ try:
+ # get all files in directory part of file_path_prefix
+ directory_part = os.path.dirname(file_path_prefix)
+ files = self.git_repo.get_contents(directory_part, ref=branch_name)
+ log_msg = "Found files %s in directory %s in branch %s"
+ log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, files, directory_part, branch_name)
+ # check if any of the files has file_path_prefix as prefix
+ for file in files:
+ if file.path.startswith(file_path_prefix):
+ log_msg = "Found file %s in directory %s in branch %s"
+ log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, file.path, directory_part, branch_name)
+ return True
+ log_msg = "No file with prefix %s found in directory %s in branch %s"
+ log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, file_path_prefix, directory_part, branch_name)
+ return False
+ except UnknownObjectException:
+ # file_path does not exist in branch
+ log_msg = "Directory %s or file with prefix %s does not exist in branch %s"
+ log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, directory_part, file_path_prefix, branch_name)
+ return False
+ except GithubException as err:
+ if err.status == 404:
+ # file_path does not exist in branch
+ log_msg = "Directory %s or file with prefix %s does not exist in branch %s"
+ log_message(LoggingScope.TASK_OPS, 'INFO', log_msg, directory_part, file_path_prefix, branch_name)
+ return False
+ else:
+ # if there was some other (e.g. connection) issue, log message and return False
+ log_msg = 'Unable to determine the state of %s, the GitHub API returned status %s!'
+                log_message(LoggingScope.ERROR, 'WARNING', log_msg, file_path_prefix, err.status)
+ return False
+ return False
+
+ @log_function_entry_exit()
+ def _determine_sequence_numbers_including_task_file(self, repo: str, pr: str) -> Dict[int, bool]:
+ """
+ Determines in which sequence numbers the metadata/task file is included and in which it is not.
+ NOTE, we only need to check the default branch of the repository, because a for a new task a file
+ is added to the default branch and for the subsequent processing of the task we use a different branch.
+ Thus, until the PR is closed, the task file stays in the default branch.
+
+ Args:
+ repo: the repository name
+ pr: the pull request number
+
+ Returns:
+ A dictionary with the sequence numbers as keys and a boolean value indicating if the metadata/task file is
+ included in that sequence number.
+
+ Idea:
+ - The deployment for a single source PR could be split into multiple staging PRs each is assigned a unique
+ sequence number.
+ - For a given source PR (identified by the repo name and the PR number), a staging PR using a branch named
+ `REPO/PR_NUM/SEQ_NUM` is created.
+ - In the staging repo we create a corresponding directory `REPO/PR_NUM/SEQ_NUM`.
+        - If a metadata/task file is handled by the staging PR with that sequence number, it is included in
+          that directory.
+ - We iterate over all directories under `REPO/PR_NUM`:
+ - If the metadata/task file is available in the directory, we add the sequence number to the list.
+
+ Note: this is a placeholder for now, as we do not know yet if we need to use a sequence number.
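+
+        Illustrative example (hypothetical names): for repo 'EESSI/software-layer' and PR '42', state files live
+        under 'EESSI/software-layer/42/<SEQ_NUM>/', e.g. 'EESSI/software-layer/42/0/<task_file_name>.<state>';
+        sequence number 0 would then be reported as including this task file.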
+ """
+ sequence_numbers = {}
+ repo_pr_dir = f"{repo}/{pr}"
+ # iterate over all directories under repo_pr_dir
+ try:
+ directories = self._list_directory_contents(repo_pr_dir)
+ for dir in directories:
+ # check if the directory is a number
+ if dir.name.isdigit():
+ # determine if a state file with prefix exists in the sequence number directory
+ # we need to use the basename of the remote file path
+ remote_file_path_basename = os.path.basename(self.description.task_object.remote_file_path)
+ state_file_name_prefix = f"{repo_pr_dir}/{dir.name}/{remote_file_path_basename}"
+ if self._state_file_with_prefix_exists_in_repo_branch(state_file_name_prefix):
+ sequence_numbers[int(dir.name)] = True
+ else:
+ sequence_numbers[int(dir.name)] = False
+ else:
+ # directory is not a number, so we skip it
+ continue
+ except FileNotFoundError:
+ # repo_pr_dir does not exist, so we return an empty dictionary
+ return {}
+ except GithubException as err:
+            if err.status != 404:  # 404 is caught by FileNotFoundError
+ # some other error than the directory not existing
+ return {}
+ return sequence_numbers
+
+ @log_function_entry_exit()
+ def _find_highest_number(self, str_list: List[str]) -> int:
+ """
+ Find the highest number in a list of strings.
+ """
+ # Convert all strings to integers
+ int_list = [int(num) for num in str_list]
+ return max(int_list)
+
+ @log_function_entry_exit()
+ def _get_sequence_number_for_task_file(self) -> int:
+ """
+ Get the sequence number this task is assigned to at the moment.
+ NOTE, should only be called if the task is actually assigned to a sequence number.
+ """
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number)
+ if len(sequence_numbers) == 0:
+ raise ValueError("Found no sequence numbers at all")
+ else:
+ # get all entries with value True, there should be only one, so we return the first one
+ sequence_numbers_true = [key for key, value in sequence_numbers.items() if value is True]
+ if len(sequence_numbers_true) == 0:
+                raise ValueError("Found no sequence numbers that include the task file for task "
+                                 f"{self.description}")
+ else:
+ return sequence_numbers_true[0]
+
+ @log_function_entry_exit()
+ def _get_current_sequence_number(self, sequence_numbers: Dict[int, bool] = None) -> int:
+ """
+ Get the current sequence number based on the sequence numbers.
+ If sequence_numbers is not provided, we determine the sequence numbers from the task description.
+ """
+ if sequence_numbers is None:
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number)
+ if len(sequence_numbers) == 0:
+ return 0
+ return self._find_highest_number(sequence_numbers.keys())
+
+ @log_function_entry_exit()
+ def _get_fixed_sequence_number(self) -> int:
+ """
+ Get a fixed sequence number.
+ """
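+        # hardcoded value; presumably a placeholder until proper sequence-number selection is implemented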
+ return 11
+
+ @log_function_entry_exit()
+    def _determine_sequence_status(self, sequence_number: int = None) -> SequenceStatus:
+ """
+ Determine the status of the sequence number. It could be: DOES_NOT_EXIST, IN_PROGRESS, FINISHED
+ If sequence_number is not provided, we use the highest existing sequence number.
+ """
+ if sequence_number is None:
+ sequence_number = self._get_current_sequence_number()
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ sequence_numbers = self._determine_sequence_numbers_including_task_file(repo_name, pr_number)
+ if len(sequence_numbers) == 0:
+ return SequenceStatus.DOES_NOT_EXIST
+ elif sequence_number not in sequence_numbers.keys():
+ return SequenceStatus.DOES_NOT_EXIST
+ elif sequence_number < self._find_highest_number(sequence_numbers.keys()):
+ return SequenceStatus.FINISHED
+ else:
+ # check status of PR if it exists
+ branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+ if branch_name in [branch.name for branch in self.git_repo.get_branches()]:
+ find_pr = [pr for pr in self.git_repo.get_pulls(head=branch_name, state='all')]
+ if find_pr:
+ pr = find_pr.pop(0)
+ if pr.state == 'closed':
+ return SequenceStatus.FINISHED
+ return SequenceStatus.IN_PROGRESS
+
+ @log_function_entry_exit()
+ def _find_staging_pr(self) -> Tuple[Optional[PullRequest], Optional[str], Optional[int]]:
+ """
+ Find the staging PR for the task.
+ TODO: arg sequence number --> make function simpler
+ """
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ try:
+ sequence_number = self._get_sequence_number_for_task_file()
+ except ValueError:
+ # no sequence number found, so we return None
+ log_message(LoggingScope.ERROR, 'ERROR', "no sequence number found for task %s", self.description)
+ return None, None, None
+ except Exception as err:
+ # some other error
+ log_message(LoggingScope.ERROR, 'ERROR', "error finding staging PR for task %s: %s",
+ self.description, err)
+ return None, None, None
+ branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+ if branch_name in [branch.name for branch in self.git_repo.get_branches()]:
+ find_pr = [pr for pr in self.git_repo.get_pulls(head=branch_name, state='all')]
+ if find_pr:
+ pr = find_pr.pop(0)
+ return pr, branch_name, sequence_number
+ else:
+ return None, branch_name, sequence_number
+ else:
+ return None, None, None
+
+ @log_function_entry_exit()
+ def _create_staging_pr(self, sequence_number: int) -> Tuple[PullRequest, str]:
+ """
+ Create a staging PR for the task.
+ NOTE, SHALL only be called if no staging PR for the task exists yet.
+ """
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+ default_branch_name = self.git_repo.default_branch
+ pr = self.git_repo.create_pull(title=f"Add task for {repo_name} PR {pr_number} seq {sequence_number}",
+ body=f"Add task for {repo_name} PR {pr_number} seq {sequence_number}",
+ head=branch_name, base=default_branch_name)
+ return pr, branch_name
+
+ @log_function_entry_exit()
+ def _find_state(self) -> TaskState:
+ """
+ Determine the state of the task based on the task description metadata.
+
+ Returns:
+ The state of the task.
+ """
+ # obtain repo and pr from metadata
+ log_message(LoggingScope.TASK_OPS, 'INFO', "finding state of task %s", self.description.task_object)
+ repo = self.description.get_repo_name()
+ pr = self.description.get_pr_number()
+ log_message(LoggingScope.TASK_OPS, 'INFO', "repo: %s, pr: %s", repo, pr)
+
+ # obtain all sequence numbers in repo/pr dir which include a state file for this task
+ sequence_numbers = self._determine_sequence_numbers_including_task_file(repo, pr)
+ if len(sequence_numbers) == 0:
+ # no sequence numbers found, so we return NEW_TASK
+ log_message(LoggingScope.TASK_OPS, 'INFO', "no sequence numbers found, state: NEW_TASK")
+ return TaskState.NEW_TASK
+ # we got at least one sequence number
+ # if one value for a sequence number is True, we can determine the state from the file in the directory
+ sequence_including_task = [key for key, value in sequence_numbers.items() if value is True]
+ if len(sequence_including_task) == 0:
+ # no sequence number includes the task file, so we return NEW_TASK
+ log_message(LoggingScope.TASK_OPS, 'INFO', "no sequence number includes the task file, state: NEW_TASK")
+ return TaskState.NEW_TASK
+ # we got at least one sequence number which includes the task file
+ # we can determine the state from the filename in the directory
+ # NOTE, we use the first element in sequence_including_task (there should be only one)
+ # we ignore other elements in sequence_including_task
+ sequence_number = sequence_including_task[0]
+ task_file_name = self.description.get_task_file_name()
+ metadata_file_state_path_prefix = f"{repo}/{pr}/{sequence_number}/{task_file_name}."
+ state = self._get_state_for_metadata_file_prefix(metadata_file_state_path_prefix, sequence_number)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state)
+ return state
+
+ @log_function_entry_exit()
+ def _get_state_for_metadata_file_prefix(self, metadata_file_state_path_prefix: str,
+ sequence_number: int) -> TaskState:
+ """
+ Get the state from the file in the metadata_file_state_path_prefix.
+ """
+ # depending on the state of the deployment (NEW_TASK, PAYLOAD_STAGED, PULL_REQUEST, APPROVED, REJECTED,
+ # INGESTED, DONE)
+ # we need to check the task file in the default branch or in the branch corresponding to the sequence number
+ directory_part = os.path.dirname(metadata_file_state_path_prefix)
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ default_branch_name = self.git_repo.default_branch
+ branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+ all_branch_names = [branch.name for branch in self.git_repo.get_branches()]
+ states = []
+ for branch in [default_branch_name, branch_name]:
+ if branch in all_branch_names:
+ # first get all files in directory part of metadata_file_state_path_prefix
+ files = self._list_directory_contents(directory_part, branch)
+ # check if any of the files has metadata_file_state_path_prefix as prefix
+ for file in files:
+ if file.path.startswith(metadata_file_state_path_prefix):
+ # get state from file name taking only the suffix
+ state = TaskState.from_string(file.name.split('.')[-1])
+ log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state)
+ states.append(state)
+ if len(states) == 0:
+ # did not find any file with metadata_file_state_path_prefix as prefix
+ log_message(LoggingScope.TASK_OPS, 'INFO', "did not find any file with prefix %s",
+ metadata_file_state_path_prefix)
+ return TaskState.NEW_TASK
+ # sort the states and return the last one
+ states.sort()
+ state = states[-1]
+ log_message(LoggingScope.TASK_OPS, 'INFO', "state: %s", state)
+ return state
+
+ @log_function_entry_exit()
+ def _list_directory_contents(self, directory_path, branch_name: str = None):
+ try:
+ # Get contents of the directory
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "listing contents of %s in branch %s", directory_path, branch_name)
+ contents = self.git_repo.get_contents(directory_path, ref=branch_name)
+
+ # If contents is a list, it means we successfully got directory contents
+ if isinstance(contents, list):
+ return contents
+ else:
+ # If it's not a list, it means the path is not a directory
+ raise ValueError(f"{directory_path} is not a directory")
+ except GithubException as err:
+ if err.status == 404:
+ raise FileNotFoundError(f"Directory not found: {directory_path}")
+ raise err
+
+ @log_function_entry_exit()
+ def _next_state(self, state: TaskState = None) -> TaskState:
+ """
+ Determine the next state based on the current state using the valid_transitions dictionary.
+
+ NOTE, it assumes that function is only called for non-terminal states and that the next state is the first
+ element of the list returned by the valid_transitions dictionary.
+ """
+ the_state = state if state is not None else self.determine_state()
+ return self.valid_transitions[the_state][0]
+
+ @log_function_entry_exit()
+ def _path_exists_in_branch(self, path: str, branch_name: str = None) -> bool:
+ """
+ Check if a path exists in a branch.
+ """
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ try:
+ self.git_repo.get_contents(path, ref=branch_name)
+ return True
+ except GithubException as err:
+ if err.status == 404:
+ return False
+ else:
+ raise err
+
+ @log_function_entry_exit()
+ def _read_dict_from_string(self, content: str) -> dict:
+ """
+ Read the dictionary from the string.
+ """
+ config_dict = {}
+ for line in content.strip().split('\n'):
+ if '=' in line and not line.strip().startswith('#'): # Skip comments
+ key, value = line.split('=', 1) # Split only on first '='
+ config_dict[key.strip()] = value.strip()
+ return config_dict
+
+ @log_function_entry_exit()
+ def _read_pull_request_dir_from_file(self, task_pointer_file: str = None, branch_name: str = None) -> str:
+ """
+ Read the pull request directory from the file in the given branch.
+ """
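+        # The task pointer file stored in the staging repo holds simple 'key = value' lines; the pull request
+        # directory is taken from its 'pull_request_dir' entry, falling back to 'target_dir' if absent.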
+ # set default values for task pointer file and branch name
+ if task_pointer_file is None:
+ task_pointer_file = self.description.task_object.remote_file_path
+ if branch_name is None:
+ branch_name = self.git_repo.default_branch
+ log_message(LoggingScope.TASK_OPS, 'INFO', "reading pull request directory from file '%s' in branch '%s'",
+ task_pointer_file, branch_name)
+
+ # read the pull request directory from the file in the given branch
+ content = self.git_repo.get_contents(task_pointer_file, ref=branch_name)
+
+ # Decode the content from base64
+ content_str = content.decoded_content.decode('utf-8')
+
+ # Parse into dictionary
+ config_dict = self._read_dict_from_string(content_str)
+
+ target_dir = config_dict.get('target_dir', None)
+ return config_dict.get('pull_request_dir', target_dir)
+
+ @log_function_entry_exit()
+ def _determine_pull_request_dir(self, task_pointer_file: str = None, branch_name: str = None) -> str:
+ """Determine the pull request directory via the task pointer file"""
+ return self._read_pull_request_dir_from_file(task_pointer_file=task_pointer_file, branch_name=branch_name)
+
+ @log_function_entry_exit()
+ def _get_branch_from_name(self, branch_name: str = None) -> Optional[Branch]:
+ """
+ Get a branch object from its name.
+ """
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+ try:
+ branch = self.git_repo.get_branch(branch_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "branch %s exists: %s", branch_name, branch)
+ return branch
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "error checking if branch %s exists: %s",
+ branch_name, err)
+ return None
+
+ @log_function_entry_exit()
+ def _read_task_state_from_file(self, path: str, branch_name: str = None) -> TaskState:
+ """
+ Read the task state from the file in the given branch.
+ """
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ content = self.git_repo.get_contents(path, ref=branch_name)
+
+ # Decode the content from base64
+ content_str = content.decoded_content.decode('utf-8').strip()
+ log_message(LoggingScope.TASK_OPS, 'INFO', "content in TaskState file: %s", content_str)
+
+ task_state = TaskState.from_string(content_str)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "task state: %s", task_state)
+
+ return task_state
+
+ @log_function_entry_exit()
+ def determine_state(self, branch: str = None) -> TaskState:
+ """
+ Determine the state of the task based on the state of the staging repository.
+ """
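+        # The state is persisted in a 'TaskState' file inside the task's pull request directory;
+        # the task pointer file in the staging repo is used to locate that directory.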
+ # check if path representing the task file exists in the default branch or the "feature" branch
+ task_pointer_file = self.description.task_object.remote_file_path
+ branch_to_use = self.git_repo.default_branch if branch is None else branch
+
+ if self._path_exists_in_branch(task_pointer_file, branch_name=branch_to_use):
+ log_message(LoggingScope.TASK_OPS, 'INFO', "path '%s' exists in branch '%s'",
+ task_pointer_file, branch_to_use)
+
+ # get state from task file in branch to use
+ # - read the TaskState file in pull request directory
+ pull_request_dir = self._determine_pull_request_dir(branch_name=branch_to_use)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "pull request directory: '%s'", pull_request_dir)
+ task_state_file_path = f"{pull_request_dir}/TaskState"
+ log_message(LoggingScope.TASK_OPS, 'INFO', "task state file path: '%s'", task_state_file_path)
+ task_state = self._read_task_state_from_file(task_state_file_path, branch_to_use)
+
+ log_message(LoggingScope.TASK_OPS, 'INFO', "task state in branch '%s': %s",
+ branch_to_use, task_state)
+ return task_state
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "path '%s' does not exist in branch '%s'",
+ task_pointer_file, branch_to_use)
+ return TaskState.UNDETERMINED
+
+ @log_function_entry_exit()
+ def handle(self):
+ """
+ Dynamically find and execute the appropriate handler based on action and state.
+ """
+ state_before_handle = self.determine_state()
+
+ # Construct handler method name
+ handler_name = f"_handle_{self.action}_{str(state_before_handle).lower()}"
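+        # e.g. for the ADD action in state NEW_TASK this resolves to something like '_handle_add_new_task'
+        # (the exact name depends on how EESSITaskAction instances render as strings)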
+
+ # Check if the handler exists
+ handler = getattr(self, handler_name, None)
+
+ if handler and callable(handler):
+ # Execute the handler if it exists
+ return handler()
+ else:
+ # Default behavior for missing handlers
+ log_message(LoggingScope.TASK_OPS, 'ERROR',
+ "No handler for action %s and state %s implemented; nothing to be done",
+ self.action, state_before_handle)
+ return state_before_handle
+
+    # Helper methods and handlers for the ADD action
+ @log_function_entry_exit()
+ def _safe_create_file(self, path: str, message: str, content: str, branch_name: str = None):
+ """Create a file in the given branch."""
+ try:
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ existing_file = self.git_repo.get_contents(path, ref=branch_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "File %s already exists", path)
+ return existing_file
+ except GithubException as err:
+ if err.status == 404: # File doesn't exist
+ # Safe to create
+ return self.git_repo.create_file(path, message, content, branch=branch_name)
+ else:
+ raise err # Some other error
+
+ @log_function_entry_exit()
+ def _create_multi_file_commit(self, files_data, commit_message, branch_name: str = None):
+ """
+ Create a commit with multiple file changes
+
+ files_data: dict with structure:
+ {
+ "path/to/file1.txt": {
+ "content": "file content",
+ "mode": "100644" # optional, defaults to 100644
+ },
+ "path/to/file2.py": {
+ "content": "print('hello')",
+ "mode": "100644"
+ }
+ }
+ """
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+ ref = self.git_repo.get_git_ref(f"heads/{branch_name}")
+ current_commit = self.git_repo.get_git_commit(ref.object.sha)
+ base_tree = current_commit.tree
+
+ # Create tree elements
+ tree_elements = []
+ for file_path, file_info in files_data.items():
+ content = file_info["content"]
+ if isinstance(content, str):
+ content = content.encode('utf-8')
+
+ blob = self.git_repo.create_git_blob(
+ base64.b64encode(content).decode('utf-8'),
+ "base64"
+ )
+ tree_elements.append(InputGitTreeElement(
+ path=file_path,
+ mode=file_info.get("mode", "100644"),
+ type="blob",
+ sha=blob.sha
+ ))
+
+ # Create new tree
+ new_tree = self.git_repo.create_git_tree(tree_elements, base_tree)
+
+ # Create commit
+ new_commit = self.git_repo.create_git_commit(
+ commit_message,
+ new_tree,
+ [current_commit]
+ )
+
+ # Update branch reference
+ ref.edit(new_commit.sha)
+
+ return new_commit
+
+ @log_function_entry_exit()
+ def _update_file(self, file_path, new_content, commit_message, branch_name: str = None) -> Optional[Dict]:
+ try:
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+ # Get the current file
+ file = self.git_repo.get_contents(file_path, ref=branch_name)
+
+ # Update the file
+ result = self.git_repo.update_file(
+ path=file_path,
+ message=commit_message,
+ content=new_content,
+ sha=file.sha,
+ branch=branch_name
+ )
+
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "File updated successfully. Commit SHA: %s", result['commit'].sha)
+ return result
+
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Error updating file: %s", err)
+ return None
+
+ @log_function_entry_exit()
+ def _sorted_list_of_sequence_numbers(self) -> List[int]:
+ """Create a sorted list of sequence numbers from the pull requests directory"""
+ # a pull request's directory is of the form REPO/PR/SEQ
+ # hence, we can get all sequence numbers from the pull requests directory REPO/PR
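+ # illustrative example (hypothetical values): for repo 'EESSI/software-layer' and PR 42, the
+ # directories 'EESSI/software-layer/42/0' and 'EESSI/software-layer/42/1' yield [0, 1]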
+ sequence_numbers = []
+ repo_pr_dir = f"{self.description.get_repo_name()}/{self.description.get_pr_number()}"
+
+ # iterate over all directories under repo_pr_dir
+ try:
+ directories = self._list_directory_contents(repo_pr_dir)
+ for directory in directories:
+ # check if the directory name is a number
+ if directory.name.isdigit():
+ sequence_numbers.append(int(directory.name))
+ else:
+ # directory name is not a number, so we skip it
+ continue
+ except FileNotFoundError:
+ # repo_pr_dir does not exist, so we return an empty list
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Pull requests directory '%s' does not exist", repo_pr_dir)
+ except GithubException as err:
+ if err.status != 404: # 404 is caught by FileNotFoundError
+ # some other error than the directory not existing
+ log_message(LoggingScope.TASK_OPS, 'ERROR',
+ "Some other error than the directory not existing: %s", err)
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Unexpected error: %s", err)
+
+ return sorted(sequence_numbers)
+
+ @log_function_entry_exit()
+ def _determine_sequence_number(self) -> int:
+ """Determine the sequence number for the task"""
+
+ sequence_numbers = self._sorted_list_of_sequence_numbers()
+ log_message(LoggingScope.TASK_OPS, 'INFO', "number of sequence numbers: %d", len(sequence_numbers))
+ if len(sequence_numbers) == 0:
+ return 0
+
+ log_message(LoggingScope.TASK_OPS, 'INFO', "sequence numbers: [%s]", ", ".join(map(str, sequence_numbers)))
+
+ # get the highest sequence number
+ highest_sequence_number = sequence_numbers[-1]
+ log_message(LoggingScope.TASK_OPS, 'INFO', "highest sequence number: %d", highest_sequence_number)
+
+ pull_request = self._find_pr_for_sequence_number(highest_sequence_number)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "pull request: %s", pull_request)
+
+ if pull_request is None:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "Did not find pull request for sequence number %d",
+ highest_sequence_number)
+ # the directory for the sequence number exists but no PR yet
+ return highest_sequence_number
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "pull request found: %s", pull_request)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "pull request state/merged: %s/%s",
+ pull_request.state, str(pull_request.is_merged()))
+ if pull_request.is_merged():
+ # the PR is merged, so we use the next sequence number
+ return highest_sequence_number + 1
+ else:
+ # the PR is not merged, so we can use the current sequence number
+ return highest_sequence_number
+
+ @log_function_entry_exit()
+ def _handle_add_undetermined(self):
+ """Handler for ADD action in UNDETERMINED state"""
+ print("Handling ADD action in UNDETERMINED state: %s" % self.description.get_task_file_name())
+ # task is in state UNDETERMINED if there is no pull request directory for the task yet
+ #
+ # create pull request directory (REPO/PR/SEQ/TASK_FILE_NAME/)
+ # create task file in pull request directory (PULL_REQUEST_DIR/TaskDescription)
+ # create task status file in pull request directory (PULL_REQUEST_DIR/TaskState.NEW_TASK)
+ # create pointer file from task file path to pull request directory (remote_file_path -> PULL_REQUEST_DIR)
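+ #
+ # illustrative layout (hypothetical values): for repo 'EESSI/software-layer', PR 42, sequence 0 and
+ # task file 'example.tar.gz.meta.txt' this commits
+ # EESSI/software-layer/42/0/example.tar.gz.meta.txt/TaskDescription
+ # EESSI/software-layer/42/0/example.tar.gz.meta.txt/TaskState (containing 'NEW_TASK')
+ # plus the pointer file at remote_file_path referring back to that pull request directory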
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ sequence_number = self._determine_sequence_number() # corresponds to an open or yet to be created PR
+ task_file_name = self.description.get_task_file_name()
+ # we cannot use self._determine_pull_request_dir() here because it requires a task pointer file
+ # and we don't have one yet
+ pull_request_dir = f"{repo_name}/{pr_number}/{sequence_number}/{task_file_name}"
+ task_description_file_path = f"{pull_request_dir}/TaskDescription"
+ task_state_file_path = f"{pull_request_dir}/TaskState"
+ remote_file_path = self.description.task_object.remote_file_path
+
+ files_to_commit = {
+ task_description_file_path: {
+ "content": self.description.get_contents(),
+ "mode": "100644"
+ },
+ task_state_file_path: {
+ "content": f"{TaskState.NEW_TASK.name}\n",
+ "mode": "100644"
+ },
+ remote_file_path: {
+ "content": f"remote_file_path = {remote_file_path}\npull_request_dir = {pull_request_dir}",
+ "mode": "100644"
+ }
+ }
+
+ branch_name = self.git_repo.default_branch
+ try:
+ commit = self._create_multi_file_commit(
+ files_to_commit,
+ f"new task for {repo_name} PR {pr_number} seq {sequence_number}",
+ branch_name=branch_name
+ )
+ log_message(LoggingScope.TASK_OPS, 'INFO', "commit created: %s", commit)
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Error creating commit: %s", err)
+ # TODO: rollback previous changes (task description file, task state file)
+ return TaskState.UNDETERMINED
+
+ # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number
+ # is still open or yet to be created); if it is not valid, perform corrective actions
+ return TaskState.NEW_TASK
+
+ @log_function_entry_exit()
+ def _update_task_state_file(self, next_state: TaskState, branch_name: str = None) -> Optional[Dict]:
+ """Update the TaskState file content in default or given branch"""
+ branch_name = self.git_repo.default_branch if branch_name is None else branch_name
+
+ task_pointer_file = self.description.task_object.remote_file_path
+ pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, branch_name)
+ task_state_file_path = f"{pull_request_dir}/TaskState"
+ arch = self.description.get_metadata_file_components()[3]
+ commit_message = f"change task state to {next_state} in {branch_name} for {arch}"
+ result = self._update_file(task_state_file_path,
+ f"{next_state.name}\n",
+ commit_message,
+ branch_name=branch_name)
+ return result
+
+ @log_function_entry_exit()
+ def _init_payload_object(self):
+ """Initialize the payload object"""
+ if self.payload is not None:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "payload object already initialized")
+ return
+
+ # get name of payload from metadata
+ payload_name = self.description.metadata['payload']['filename']
+ log_message(LoggingScope.TASK_OPS, 'INFO', "payload_name: %s", payload_name)
+
+ # get config and remote_client from self.description.task_object
+ config = self.description.task_object.config
+ remote_client = self.description.task_object.remote_client
+
+ # determine remote_file_path by replacing basename of remote_file_path in self.description.task_object
+ # with payload_name
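+ # e.g. (illustrative, assuming the usual naming convention) the task file
+ # '<dir>/eessi-...-1745557626.tar.gz.meta.txt' maps to the payload '<dir>/eessi-...-1745557626.tar.gz'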
+ description_remote_file_path = self.description.task_object.remote_file_path
+ payload_remote_file_path = os.path.join(os.path.dirname(description_remote_file_path), payload_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "payload_remote_file_path: %s", payload_remote_file_path)
+
+ # initialize payload object
+ payload_object = EESSIDataAndSignatureObject(config, payload_remote_file_path, remote_client)
+ self.payload = EESSITaskPayload(payload_object)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "payload: %s", self.payload)
+
+ @log_function_entry_exit()
+ def _handle_add_new_task(self):
+ """Handler for ADD action in NEW_TASK state"""
+ print("Handling ADD action in NEW_TASK state: %s" % self.description.get_task_file_name())
+ # determine next state
+ next_state = self._next_state(TaskState.NEW_TASK)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "next_state: %s", next_state)
+
+ # initialize payload object
+ self._init_payload_object()
+
+ # update TaskState file content
+ self._update_task_state_file(next_state)
+
+ # TODO: verify that the sequence number is still valid (PR corresponding to the sequence number
+ # is still open or yet to be created); if it is not valid, perform corrective actions
+ return next_state
+
+ @log_function_entry_exit()
+ def _find_pr_for_branch(self, branch_name: str) -> Optional[PullRequest]:
+ """
+ Find the single PR for the given branch in any state.
+
+ Args:
+ repo: GitHub repository
+ branch_name: Name of the branch
+
+ Returns:
+ PullRequest object if found, None otherwise
+ """
+ try:
+ prs = [pr for pr in list(self.git_repo.get_pulls(state='all'))
+ if pr.head.ref == branch_name]
+ log_message(LoggingScope.TASK_OPS, 'INFO', "number of PRs found: %d", len(prs))
+ if len(prs):
+ log_message(LoggingScope.TASK_OPS, 'INFO', "1st PR found: %d, %s", prs[0].number, prs[0].head.ref)
+ return prs[0] if prs else None
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Error finding PR for branch %s: %s", branch_name, err)
+ return None
+
+ @log_function_entry_exit()
+ def _find_pr_for_sequence_number(self, sequence_number: int) -> Optional[PullRequest]:
+ """Find the PR for the given sequence number"""
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ feature_branch_name = f"{repo_name.replace('/', '-')}-PR-{pr_number}-SEQ-{sequence_number}"
+
+ # list all PRs with head_ref starting with the feature branch name without the sequence number
+ last_dash = feature_branch_name.rfind('-')
+ if last_dash != -1:
+ head_ref_wout_seq_num = feature_branch_name[:last_dash + 1] # +1 to include the separator
+ else:
+ head_ref_wout_seq_num = feature_branch_name
+
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "searching for PRs whose head_ref starts with: '%s'", head_ref_wout_seq_num)
+
+ all_prs = [pr for pr in list(self.git_repo.get_pulls(state='all'))
+ if pr.head.ref.startswith(head_ref_wout_seq_num)]
+ log_message(LoggingScope.TASK_OPS, 'INFO', " number of PRs found: %d", len(all_prs))
+ for pr in all_prs:
+ log_message(LoggingScope.TASK_OPS, 'INFO', " PR #%d: %s", pr.number, pr.head.ref)
+
+ # now, find the PR for the feature branch name (if any)
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "searching PR for feature branch name: '%s'", feature_branch_name)
+ pull_request = self._find_pr_for_branch(feature_branch_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "pull request for branch '%s': %s",
+ feature_branch_name, pull_request)
+ return pull_request
+
+ @log_function_entry_exit()
+ def _determine_sequence_number_from_pull_request_directory(self) -> int:
+ """Determine the sequence number from the pull request directory name"""
+ task_pointer_file = self.description.task_object.remote_file_path
+ pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch)
+ # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo)
+ _, _, _, seq, _ = pull_request_dir.split('/')
+ return int(seq)
+
+ @log_function_entry_exit()
+ def _determine_feature_branch_name(self) -> str:
+ """Determine the feature branch name from the pull request directory name"""
+ task_pointer_file = self.description.task_object.remote_file_path
+ pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch)
+ # pull_request_dir is of the form REPO/PR/SEQ/TASK_FILE_NAME/ (REPO contains a '/' separating the org and repo)
+ org, repo, pr, seq, _ = pull_request_dir.split('/')
+ return f"{org}-{repo}-PR-{pr}-SEQ-{seq}"
+
+ @log_function_entry_exit()
+ def _sync_task_state_file(self, source_branch: str, target_branch: str):
+ """Update task state file from source to target branch"""
+ task_pointer_file = self.description.task_object.remote_file_path
+ pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, self.git_repo.default_branch)
+ task_state_file_path = f"{pull_request_dir}/TaskState"
+
+ try:
+ # Get content from source branch
+ source_content = self.git_repo.get_contents(task_state_file_path, ref=source_branch)
+
+ # Get current file in target branch
+ target_file = self.git_repo.get_contents(task_state_file_path, ref=target_branch)
+
+ # Update if content is different
+ if source_content.sha != target_file.sha:
+ result = self.git_repo.update_file(
+ path=task_state_file_path,
+ message=f"Sync {task_state_file_path} from {source_branch} to {target_branch}",
+ content=source_content.decoded_content,
+ sha=target_file.sha,
+ branch=target_branch
+ )
+ log_message(LoggingScope.TASK_OPS, 'INFO', "Updated %s", task_state_file_path)
+ return result
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO', "No changes needed for %s", task_state_file_path)
+ return None
+
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR', "Error syncing task state file: %s", err)
+ return None
+
+ @log_function_entry_exit()
+ def _update_task_states(self, next_state: TaskState, default_branch_name: str,
+ approved_state: TaskState, feature_branch_name: str):
+ """
+ Update task states in default and feature branches
+
+ States have to be updated in a specific order and in particular the default branch has to be
+ merged into the feature branch before the feature branch can be updated to avoid a merge conflict.
+
+ Args:
+ next_state: next state to be applied to the default branch
+ default_branch_name: name of the default branch
+ approved_state: state to be applied to the feature branch
+ feature_branch_name: name of the feature branch
+ """
+ # TODO: add failure handling (capture failures and return them somehow)
+
+ # update TaskState file content
+ # - next_state in default branch (interpreted as current state)
+ # - approved_state in feature branch (interpreted as future state, ie, after
+ # the PR corresponding to the feature branch will be merged)
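+ #
+ # order of the operations below: (1) update the default branch, (2) merge the default branch into
+ # the feature branch, (3) update the feature branch (see the docstring for why this order matters)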
+
+ # first, update the task state file in the default branch
+ self._update_task_state_file(next_state, branch_name=default_branch_name)
+
+ # second, merge default branch into feature branch (to avoid a merge conflict)
+ # TODO: store arch info (CPU+ACCEL) in the task/metadata file and then access it rather
+ # than using a part of the file name
+ arch = self.description.get_metadata_file_components()[3]
+ commit_message = f"merge {default_branch_name} into {feature_branch_name} for {arch}"
+ self.git_repo.merge(
+ head=default_branch_name,
+ base=feature_branch_name,
+ commit_message=commit_message
+ )
+
+ # last, update task state file in feature branch
+ self._update_task_state_file(approved_state, branch_name=feature_branch_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "TaskState file updated to %s in default branch (%s) and to %s in feature branch (%s)",
+ next_state, default_branch_name, approved_state, feature_branch_name)
+
+ @log_function_entry_exit()
+ def _create_task_summary(self) -> str:
+ """Analyse contents of current task and create a file for it in the REPO-PR-SEQ directory."""
+
+ # determine task summary file path in feature branch on GitHub
+ feature_branch_name = self._determine_feature_branch_name()
+ pull_request_dir = self._determine_pull_request_dir(branch_name=feature_branch_name)
+ task_summary_file_path = f"{pull_request_dir}/TaskSummary.html"
+
+ # check if task summary file already exists in repo on GitHub
+ if self._path_exists_in_branch(task_summary_file_path, feature_branch_name):
+ log_message(LoggingScope.TASK_OPS, 'INFO', "task summary file already exists: %s", task_summary_file_path)
+ existing_summary = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name)
+ # return the decoded contents so the declared return type (str) is honoured
+ return existing_summary.decoded_content.decode('utf-8')
+
+ # create task summary
+ payload_name = self.description.metadata['payload']['filename']
+ payload_summary = self.payload.analyse_contents(self.config)
+ metadata_contents = self.description.get_contents()
+
+ task_summary = self.config['github']['task_summary_payload_template'].format(
+ payload_name=payload_name,
+ metadata_contents=metadata_contents,
+ payload_overview=payload_summary
+ )
+
+ # create HTML file with task summary in REPO-PR-SEQ directory
+ # TODO: add failure handling (capture result and act on it)
+ task_file_name = self.description.get_task_file_name()
+ commit_message = f"create summary for {task_file_name} in {feature_branch_name}"
+ self._safe_create_file(task_summary_file_path, commit_message, task_summary,
+ branch_name=feature_branch_name)
+ log_message(LoggingScope.TASK_OPS, 'INFO', "task summary file created: %s", task_summary_file_path)
+
+ # return task summary
+ return task_summary
+
+ @log_function_entry_exit()
+ def _create_pr_contents_overview(self) -> str:
+ """Create a contents overview for the pull request"""
+ # TODO: implement
+ feature_branch_name = self._determine_feature_branch_name()
+ task_pointer_file = self.description.task_object.remote_file_path
+ pull_request_dir = self._read_pull_request_dir_from_file(task_pointer_file, feature_branch_name)
+ pr_dir = os.path.dirname(pull_request_dir)
+ directories = self._list_directory_contents(pr_dir, feature_branch_name)
+ contents_overview = ""
+ if directories:
+ contents_overview += "\n"
+ for directory in directories:
+ task_summary_file_path = f"{pr_dir}/{directory.name}/TaskSummary.html"
+ if self._path_exists_in_branch(task_summary_file_path, feature_branch_name):
+ file_contents = self.git_repo.get_contents(task_summary_file_path, ref=feature_branch_name)
+ task_summary = base64.b64decode(file_contents.content).decode('utf-8')
+ contents_overview += f"{task_summary}\n"
+ else:
+ contents_overview += f"Task summary file not found: {task_summary_file_path}\n"
+ contents_overview += "\n"
+ else:
+ contents_overview += "No tasks found in this PR\n"
+
+ print(f"contents_overview: {contents_overview}")
+ return contents_overview
+
+ @log_function_entry_exit()
+ def _create_pull_request(self, feature_branch_name: str, default_branch_name: str):
+ """
+ Create a PR from the feature branch to the default branch
+
+ Args:
+ feature_branch_name: name of the feature branch
+ default_branch_name: name of the default branch
+ """
+ pr_title_format = self.config['github']['grouped_pr_title']
+ pr_body_format = self.config['github']['grouped_pr_body']
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ pr_url = f"https://github.com/{repo_name}/pull/{pr_number}"
+ seq_num = self._determine_sequence_number_from_pull_request_directory()
+ pr_title = pr_title_format.format(
+ cvmfs_repo=self.cvmfs_repo,
+ pr=pr_number,
+ repo=repo_name,
+ seq_num=seq_num,
+ )
+ self._create_task_summary()
+ contents_overview = self._create_pr_contents_overview()
+ pr_body = pr_body_format.format(
+ cvmfs_repo=self.cvmfs_repo,
+ pr=pr_number,
+ pr_url=pr_url,
+ repo=repo_name,
+ seq_num=seq_num,
+ contents=contents_overview,
+ analysis="
TO BE DONE",
+ action="TO BE DONE",
+ )
+ pr = self.git_repo.create_pull(
+ title=pr_title,
+ body=pr_body,
+ head=feature_branch_name,
+ base=default_branch_name
+ )
+ log_message(LoggingScope.TASK_OPS, 'INFO', "PR created: %s", pr)
+
+ @log_function_entry_exit()
+ def _update_pull_request(self, pull_request: PullRequest):
+ """
+ Update the pull request
+
+ Args:
+ pull_request: instance of the pull request
+ """
+ # TODO: update sections (contents analysis, action)
+ repo_name = self.description.get_repo_name()
+ pr_number = self.description.get_pr_number()
+ pr_url = f"https://github.com/{repo_name}/pull/{pr_number}"
+ seq_num = self._determine_sequence_number_from_pull_request_directory()
+
+ self._create_task_summary()
+ contents_overview = self._create_pr_contents_overview()
+ pr_body_format = self.config['github']['grouped_pr_body']
+ pr_body = pr_body_format.format(
+ cvmfs_repo=self.cvmfs_repo,
+ pr=pr_number,
+ pr_url=pr_url,
+ repo=repo_name,
+ seq_num=seq_num,
+ contents=contents_overview,
+ analysis="TO BE DONE",
+ action="TO BE DONE",
+ )
+ pull_request.edit(body=pr_body)
+
+ log_message(LoggingScope.TASK_OPS, 'INFO', "PR updated: %s", pull_request)
+
+ @log_function_entry_exit()
+ def _handle_add_payload_staged(self):
+ """Handler for ADD action in PAYLOAD_STAGED state"""
+ print("Handling ADD action in PAYLOAD_STAGED state: %s" % self.description.get_task_file_name())
+ next_state = self._next_state(TaskState.PAYLOAD_STAGED)
+ approved_state = TaskState.APPROVED
+ log_message(LoggingScope.TASK_OPS, 'INFO', "next_state: %s, approved_state: %s", next_state, approved_state)
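+ # flow implemented below: ensure the feature branch exists (creating it from the default branch if
+ # needed); if no PR exists for that branch yet, update the task states and open a staging PR; if an
+ # open PR exists, update the task states and refresh the PR body; a closed PR is only logged for now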
+
+ default_branch_name = self.git_repo.default_branch
+ default_branch = self._get_branch_from_name(default_branch_name)
+ default_sha = default_branch.commit.sha
+ feature_branch_name = self._determine_feature_branch_name()
+ feature_branch = self._get_branch_from_name(feature_branch_name)
+ if not feature_branch:
+ # feature branch does not exist
+ # TODO: could have been merged already --> check if PR corresponding to the feature branch exists
+ # ASSUME: it has not existed before --> create it
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "branch %s does not exist, creating it", feature_branch_name)
+
+ feature_branch = self.git_repo.create_git_ref(f"refs/heads/{feature_branch_name}", default_sha)
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "branch %s created: %s", feature_branch_name, feature_branch)
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "found existing branch for %s: %s", feature_branch_name, feature_branch)
+
+ pull_request = self._find_pr_for_branch(feature_branch_name)
+ if not pull_request:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "no PR found for branch %s", feature_branch_name)
+
+ # TODO: add failure handling (capture result and act on it)
+ self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name)
+
+ # TODO: add failure handling (capture result and act on it)
+ self._create_pull_request(feature_branch_name, default_branch_name)
+
+ return TaskState.PULL_REQUEST
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "found existing PR for branch %s: %s", feature_branch_name, pull_request)
+ # TODO: check if PR is open or closed
+ if pull_request.state == 'closed':
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "PR %s is closed, creating issue", pull_request)
+ # TODO: create issue
+ return TaskState.PAYLOAD_STAGED
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "PR %s is open, updating task states", pull_request)
+ # TODO: add failure handling (capture result and act on it)
+ # THINK about what a failure would mean and what to do about it.
+ self._update_task_states(next_state, default_branch_name, approved_state, feature_branch_name)
+
+ # TODO: add failure handling (capture result and act on it)
+ self._update_pull_request(pull_request)
+
+ return TaskState.PULL_REQUEST
+
+ @log_function_entry_exit()
+ def _handle_add_pull_request(self):
+ """Handler for ADD action in PULL_REQUEST state"""
+ print("Handling ADD action in PULL_REQUEST state: %s" % self.description.get_task_file_name())
+ # Implementation for adding in PULL_REQUEST state
+ # we got here because the state of the task is PULL_REQUEST in the default branch
+ # determine branch and PR and state of PR
+ # PR is open --> just return TaskState.PULL_REQUEST
+ # PR is closed & merged --> deployment is approved
+ # PR is closed & not merged --> deployment is rejected
+ feature_branch_name = self._determine_feature_branch_name()
+ # TODO: check if feature branch exists, for now ASSUME it does
+ pull_request = self._find_pr_for_branch(feature_branch_name)
+ if pull_request:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "found PR for branch %s: %s", feature_branch_name, pull_request)
+ if pull_request.state == 'closed':
+ if pull_request.merged:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "PR %s is closed and merged, returning APPROVED state", pull_request)
+ # TODO: How could we have ended up here? The state in the default branch is PULL_REQUEST,
+ # but the PR is merged, so the task should have been in the APPROVED state.
+ # ==> for now, just return TaskState.PULL_REQUEST
+ #
+ # There is the possibility that the PR was updated just before it was merged.
+ # WHY is that a problem? Because a task may have been accepted that wouldn't
+ # have been accepted or, worse, shouldn't have been accepted.
+ # WHAT to do? ACCEPT/IGNORE THE ISSUE FOR NOW.
+ # HOWEVER, the contents of the PR directory may be inconsistent with
+ # respect to the TaskState file, and the TaskSummary.html file may be missing.
+ # WE could create an issue and only return TaskState.APPROVED once the
+ # issue is closed.
+ # WE could also defer all handling of this to the handler for the
+ # APPROVED state.
+ # NOPE, we have to do some handling here, at least for the tasks whose
+ # state file did
+ # --> check whether we could have ended up here; if so, create an issue.
+ # Do we need a state ISSUE_OPENED to avoid processing the task again?
+ return TaskState.PULL_REQUEST
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "PR %s is closed and not merged, returning REJECTED state", pull_request)
+ # TODO: there is the possibility that the PR was updated just before it was closed.
+ # WHY is that a problem? Because a task may have been rejected that wouldn't
+ # have been rejected or, worse, shouldn't have been rejected.
+ # WHAT to do? ACCEPT/IGNORE THE ISSUE FOR NOW.
+ # HOWEVER, the contents of the PR directory may be inconsistent with
+ # respect to the TaskState file, and the TaskSummary.html file may be missing.
+ # WE could create an issue and only return TaskState.REJECTED once the
+ # issue is closed.
+ # WE could also defer all handling of this to the handler for the
+ # REJECTED state.
+ # FOR NOW, we assume that the task was rejected on purpose
+ # we need to change the state of the task in the default branch to REJECTED
+ self._update_task_state_file(TaskState.REJECTED)
+ return TaskState.REJECTED
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "PR %s is open, returning PULL_REQUEST state", pull_request)
+ return TaskState.PULL_REQUEST
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ "no PR found for branch %s", feature_branch_name)
+ # the method was called because the state of the task is PULL_REQUEST in the default branch,
+ # so it is unexpected that no PR was found for the feature branch
+ # TODO: may create or update an issue for the task or deployment
+ return TaskState.PULL_REQUEST
+
+ return TaskState.PULL_REQUEST
+
+ @log_function_entry_exit()
+ def _perform_task_action(self) -> bool:
+ """Perform the task action"""
+ # TODO: support other actions than ADD
+ if self.action == EESSITaskAction.ADD:
+ return self._perform_task_add()
+ else:
+ raise ValueError(f"Task action '{self.action}' not supported (yet)")
+
+ @log_function_entry_exit()
+ def _issue_exists(self, title: str, state: str = 'open') -> bool:
+ """
+ Check if an issue with the given title and state already exists.
+ """
+ issues = self.git_repo.get_issues(state=state)
+ for issue in issues:
+ if issue.title == title and issue.state == state:
+ return True
+ # no matching issue found
+ return False
+
+ @log_function_entry_exit()
+ def _perform_task_add(self) -> bool:
+ """Perform the ADD task action"""
+ # TODO: verify checksum here or before?
+ script = self.config['paths']['ingestion_script']
+ sudo = ['sudo'] if self.config['cvmfs'].getboolean('ingest_as_root', True) else []
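+ # resulting command (illustrative): [sudo] <ingestion_script> <cvmfs_repo> <local_payload_path>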
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Running the ingestion script for %s...\n with script: %s\n with sudo: %s',
+ self.description.get_task_file_name(),
+ script, 'no' if sudo == [] else 'yes')
+ ingest_cmd = subprocess.run(
+ sudo + [script, self.cvmfs_repo, str(self.payload.payload_object.local_file_path)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Ingestion script returned code %s', ingest_cmd.returncode)
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Ingestion script stdout: %s', ingest_cmd.stdout.decode('UTF-8'))
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Ingestion script stderr: %s', ingest_cmd.stderr.decode('UTF-8'))
+ if ingest_cmd.returncode == 0:
+ next_state = self._next_state(TaskState.APPROVED)
+ self._update_task_state_file(next_state)
+ if self.config.has_section('slack') and self.config['slack'].getboolean('ingestion_notification', False):
+ send_slack_message(
+ self.config['secrets']['slack_webhook'],
+ self.config['slack']['ingestion_message'].format(
+ tarball=os.path.basename(self.payload.payload_object.local_file_path),
+ cvmfs_repo=self.cvmfs_repo)
+ )
+ return True
+ else:
+ tarball = os.path.basename(self.payload.payload_object.local_file_path)
+ log_message(LoggingScope.STATE_OPS, 'ERROR',
+ 'Failed to add %s, return code %s',
+ tarball,
+ ingest_cmd.returncode)
+
+ issue_title = f'Failed to add {tarball}'
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ "Creating issue for failed ingestion: title: '%s'",
+ issue_title)
+
+ command = ' '.join(ingest_cmd.args)
+ failed_ingestion_issue_body = self.config['github']['failed_ingestion_issue_body']
+ issue_body = failed_ingestion_issue_body.format(
+ command=command,
+ tarball=tarball,
+ return_code=ingest_cmd.returncode,
+ stdout=ingest_cmd.stdout.decode('UTF-8'),
+ stderr=ingest_cmd.stderr.decode('UTF-8')
+ )
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ "Creating issue for failed ingestion: body: '%s'",
+ issue_body)
+
+ if self._issue_exists(issue_title, state='open'):
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Failed to add %s, but an open issue already exists, skipping...',
+ os.path.basename(self.payload.payload_object.local_file_path))
+ else:
+ log_message(LoggingScope.STATE_OPS, 'INFO',
+ 'Failed to add %s, but an open issue does not exist, creating one...',
+ os.path.basename(self.payload.payload_object.local_file_path))
+ self.git_repo.create_issue(title=issue_title, body=issue_body)
+ return False
+
+ @log_function_entry_exit()
+ def _handle_add_approved(self):
+ """Handler for ADD action in APPROVED state"""
+ print("Handling ADD action in APPROVED state: %s" % self.description.get_task_file_name())
+ # Implementation for adding in APPROVED state
+ # If successful, _perform_task_action() will change the state
+ # to INGESTED on GitHub
+ try:
+ if self._perform_task_action():
+ return TaskState.INGESTED
+ else:
+ return TaskState.APPROVED
+ except Exception as err:
+ log_message(LoggingScope.TASK_OPS, 'ERROR',
+ "Error performing task action: '%s'\nTraceback:\n%s", err, traceback.format_exc())
+ return TaskState.APPROVED
+
+ @log_function_entry_exit()
+ def _handle_add_ingested(self):
+ """Handler for ADD action in INGESTED state"""
+ print("Handling ADD action in INGESTED state: %s" % self.description.get_task_file_name())
+ # Implementation for adding in INGESTED state
+ # DON'T change state on GitHub, because the result
+ # (INGESTED/REJECTED) would be overwritten
+ return TaskState.DONE
+
+ @log_function_entry_exit()
+ def _handle_add_rejected(self):
+ """Handler for ADD action in REJECTED state"""
+ print("Handling ADD action in REJECTED state: %s" % self.description.get_task_file_name())
+ # Implementation for adding in REJECTED state
+ # DON'T change state on GitHub, because the result
+ # (INGESTED/REJECTED) would be overwritten
+ return TaskState.DONE
+
+ @log_function_entry_exit()
+ def __str__(self):
+ return f"EESSITask(description={self.description}, action={self.action}, state={self.determine_state()})"
diff --git a/scripts/automated_ingestion/eessi_task_action.py b/scripts/automated_ingestion/eessi_task_action.py
new file mode 100644
index 00000000..6f141435
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_task_action.py
@@ -0,0 +1,12 @@
+from enum import Enum, auto
+
+
+class EESSITaskAction(Enum):
+ NOP = auto() # perform no action
+ DELETE = auto() # perform a delete operation
+ ADD = auto() # perform an add operation
+ UPDATE = auto() # perform an update operation
+ UNKNOWN = auto() # unknown action
+
+ def __str__(self):
+ return self.name.lower()
diff --git a/scripts/automated_ingestion/eessi_task_description.py b/scripts/automated_ingestion/eessi_task_description.py
new file mode 100644
index 00000000..5ff4c196
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_task_description.py
@@ -0,0 +1,188 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+from eessi_data_object import EESSIDataAndSignatureObject
+from utils import log_function_entry_exit, log_message, LoggingScope
+from remote_storage import DownloadMode
+
+
+@dataclass
+class EESSITaskDescription:
+ """Class representing an EESSI task to be performed, including its metadata and associated data files."""
+
+ # The EESSI data and signature object associated with this task
+ task_object: EESSIDataAndSignatureObject
+
+ # Whether the signature was successfully verified
+ signature_verified: bool = False
+
+ # Metadata from the task description file
+ metadata: Dict[str, Any] = None
+
+ # task element
+ task: Dict[str, Any] = None
+
+ # source element
+ source: Dict[str, Any] = None
+
+ @log_function_entry_exit()
+ def __init__(self, task_object: EESSIDataAndSignatureObject):
+ """
+ Initialize an EESSITaskDescription object.
+
+ Args:
+ task_object: The EESSI data and signature object associated with this task
+ """
+ self.task_object = task_object
+ self.metadata = {}
+
+ self.task_object.download(mode=DownloadMode.CHECK_REMOTE)
+
+ # Verify signature and set initial state
+ self.signature_verified = self.task_object.verify_signature()
+
+ # Try to read metadata (will only succeed if signature is verified)
+ try:
+ self._read_metadata()
+ except RuntimeError:
+ # Expected if signature is not verified yet
+ pass
+
+ # TODO: Process the task file contents
+ # check if the task file contains a task field and add that to self
+ if 'task' in self.metadata:
+ self.task = self.metadata['task']
+ else:
+ self.task = None
+
+ # check if the task file contains a link2pr field and add that to source element
+ if 'link2pr' in self.metadata:
+ self.source = self.metadata['link2pr']
+ else:
+ self.source = None
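+
+ # illustrative shape of a task description / metadata file (hypothetical values; only the fields
+ # referenced by the automated ingestion scripts are shown):
+ # {
+ # "task": {...},
+ # "link2pr": {"repo": "EESSI/software-layer", "pr": "42"},
+ # "payload": {"filename": "eessi-...-1745557626.tar.gz", "sha256sum": "..."}
+ # }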
+
+ @log_function_entry_exit()
+ def _read_metadata(self) -> None:
+ """
+ Internal method to read and parse the metadata from the task description file.
+ Only reads metadata if the signature has been verified.
+ """
+ if not self.signature_verified:
+ log_message(LoggingScope.ERROR, 'ERROR', "Cannot read metadata: signature not verified for %s",
+ self.task_object.local_file_path)
+ raise RuntimeError("Cannot read metadata: signature not verified")
+
+ try:
+ with open(self.task_object.local_file_path, 'r') as file:
+ self.raw_contents = file.read()
+ self.metadata = json.loads(self.raw_contents)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Successfully read metadata from %s",
+ self.task_object.local_file_path)
+ except json.JSONDecodeError as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to parse JSON in task description file %s: %s",
+ self.task_object.local_file_path, str(err))
+ raise
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to read task description file %s: %s",
+ self.task_object.local_file_path, str(err))
+ raise
+
+ @log_function_entry_exit()
+ def get_contents(self) -> str:
+ """
+ Get the contents of the task description / metadata file.
+ """
+ return self.raw_contents
+
+ @log_function_entry_exit()
+ def get_metadata_file_components(self) -> Tuple[str, str, str, str, str, str]:
+ """
+ Get the components of the metadata file name.
+
+ An example of the metadata file name is:
+ eessi-2023.06-software-linux-x86_64-amd-zen2-1745557626.tar.gz.meta.txt
+
+ The components are:
+ eessi: some prefix
+ VERSION: 2023.06
+ COMPONENT: software
+ OS: linux
+ ARCHITECTURE: x86_64-amd-zen2
+ TIMESTAMP: 1745557626
+ SUFFIX: tar.gz.meta.txt
+
+ The ARCHITECTURE component can include one to two hyphens.
+ The SUFFIX is the part after the first dot (no other components should include dots).
+ """
+ # obtain file name from local file path using basename
+ file_name = Path(self.task_object.local_file_path).name
+ # split file_name into part before suffix and the suffix
+ # idea: split on last hyphen, then split on first dot
+ suffix = file_name.split('-')[-1].split('.', 1)[1]
+ file_name_without_suffix = file_name[:-(len(suffix) + 1)] # drop the trailing ".<suffix>" (str.strip would remove characters, not the substring)
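+ # for the example in the docstring: suffix = 'tar.gz.meta.txt' and
+ # file_name_without_suffix = 'eessi-2023.06-software-linux-x86_64-amd-zen2-1745557626'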
+ # from file_name_without_suffix determine VERSION (2nd element), COMPONENT (3rd element), OS (4th element),
+ # ARCHITECTURE (5th to second last elements) and TIMESTAMP (last element)
+ components = file_name_without_suffix.split('-')
+ version = components[1]
+ component = components[2]
+ os = components[3]
+ architecture = '-'.join(components[4:-1])
+ timestamp = components[-1]
+ return version, component, os, architecture, timestamp, suffix
+
+ @log_function_entry_exit()
+ def get_metadata_value(self, key: str) -> str:
+ """
+ Get the value of a key from the task description / metadata file.
+ """
+ # check that key is defined and has a length > 0
+ if not key or len(key) == 0:
+ raise ValueError("get_metadata_value: key is not defined or has a length of 0")
+
+ value = None
+ task = self.task
+ source = self.source
+ # check if key is in task or source
+ if task and key in task:
+ value = task[key]
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ f"Value '{value}' for key '{key}' found in information from task metadata: {task}")
+ elif source and key in source:
+ value = source[key]
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ f"Value '{value}' for key '{key}' found in information from source metadata: {source}")
+ else:
+ log_message(LoggingScope.TASK_OPS, 'INFO',
+ f"Value for key '{key}' neither found in task metadata nor source metadata")
+ raise ValueError(f"Value for key '{key}' neither found in task metadata nor source metadata")
+ return value
+
+ @log_function_entry_exit()
+ def get_pr_number(self) -> str:
+ """
+ Get the PR number from the task description / metadata file.
+ """
+ return self.get_metadata_value('pr')
+
+ @log_function_entry_exit()
+ def get_repo_name(self) -> str:
+ """
+ Get the repository name from the task description / metadata file.
+ """
+ return self.get_metadata_value('repo')
+
+ @log_function_entry_exit()
+ def get_task_file_name(self) -> str:
+ """
+ Get the file name from the task description / metadata file.
+ """
+ # get file name from remote file path using basename
+ file_name = Path(self.task_object.remote_file_path).name
+ return file_name
+
+ @log_function_entry_exit()
+ def __str__(self) -> str:
+ """Return a string representation of the EESSITaskDescription object."""
+ return f"EESSITaskDescription({self.task_object.local_file_path}, verified={self.signature_verified})"
diff --git a/scripts/automated_ingestion/eessi_task_payload.py b/scripts/automated_ingestion/eessi_task_payload.py
new file mode 100644
index 00000000..cb39cc81
--- /dev/null
+++ b/scripts/automated_ingestion/eessi_task_payload.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+import tarfile
+from pathlib import PurePosixPath
+import os
+from typing import Dict
+
+from eessi_data_object import EESSIDataAndSignatureObject
+from utils import log_function_entry_exit
+from remote_storage import DownloadMode
+
+
+@dataclass
+class EESSITaskPayload:
+ """Class representing an EESSI task payload (tarball/artifact) and its signature."""
+
+ # The EESSI data and signature object associated with this payload
+ payload_object: EESSIDataAndSignatureObject
+
+ # Whether the signature was successfully verified
+ signature_verified: bool = False
+
+ # possibly at a later point in time, we will add inferred metadata here
+ # such as the prefix in a tarball, the main elements, or which software
+ # package it includes
+
+ @log_function_entry_exit()
+ def __init__(self, payload_object: EESSIDataAndSignatureObject):
+ """
+ Initialize an EESSITaskPayload object.
+
+ Args:
+ payload_object: The EESSI data and signature object associated with this payload
+ """
+ self.payload_object = payload_object
+
+ # Download the payload and its signature
+ self.payload_object.download(mode=DownloadMode.CHECK_REMOTE)
+
+ # Verify signature
+ self.signature_verified = self.payload_object.verify_signature()
+
+ @log_function_entry_exit()
+ def analyse_contents(self, config: Dict) -> str:
+ """Analyse the contents of the payload and return a summary in a ready-to-use HTML format."""
+ tar = tarfile.open(self.payload_object.local_file_path, 'r')
+ members = tar.getmembers()
+ tar_num_members = len(members)
+ paths = sorted([m.path for m in members])
+
+ if tar_num_members < 100:
+ tar_members_desc = "Full listing of the contents of the tarball:"
+ members_list = paths
+
+ else:
+ tar_members_desc = "Summarized overview of the contents of the tarball:"
+ # determine prefix after filtering out '/init' subdirectory,
+ # to get actual prefix for specific CPU target (like '2023.06/software/linux/aarch64/neoverse_v1')
+ init_subdir = os.path.join('*', 'init')
+ non_init_paths = sorted(
+ [path for path in paths if not any(parent.match(init_subdir) for parent in PurePosixPath(path).parents)]
+ )
+ if non_init_paths:
+ prefix = os.path.commonprefix(non_init_paths)
+ else:
+ prefix = os.path.commonprefix(paths)
+
+ # TODO: this only works for software tarballs, how to handle compat layer tarballs?
+ swdirs = [ # all directory names with the pattern: <prefix>/software/<name>/<version>
+ member.path
+ for member in members
+ if member.isdir() and PurePosixPath(member.path).match(os.path.join(prefix, 'software', '*', '*'))
+ ]
+ modfiles = [ # all filenames with the pattern: <prefix>/modules/<category>/<name>/*.lua
+ member.path
+ for member in members
+ if member.isfile() and
+ PurePosixPath(member.path).match(os.path.join(prefix, 'modules', '*', '*', '*.lua'))
+ ]
+ other = [ # anything that is not in <prefix>/software nor <prefix>/modules
+ member.path
+ for member in members
+ if (not PurePosixPath(prefix).joinpath('software') in PurePosixPath(member.path).parents
+ and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(member.path).parents)
+ # if not fnmatch.fnmatch(m.path, os.path.join(prefix, 'software', '*'))
+ # and not fnmatch.fnmatch(m.path, os.path.join(prefix, 'modules', '*'))
+ ]
+ members_list = sorted(swdirs + modfiles + other)
+
+ # Construct the overview
+ overview = config['github']['task_summary_payload_overview_template'].format(
+ tar_num_members=tar_num_members,
+ bucket_url=self.payload_object.remote_client.get_bucket_url(),
+ remote_file_path=self.payload_object.remote_file_path,
+ tar_members_desc=tar_members_desc,
+ tar_members='\n'.join(members_list)
+ )
+
+ # Make sure that the overview does not exceed Github's maximum length (65536 characters)
+ if len(overview) > 60000:
+ overview = overview[:60000] + "\n\nWARNING: output exceeded the maximum length and was truncated!\n```"
+ return overview
+
+ @log_function_entry_exit()
+ def __str__(self) -> str:
+ """Return a string representation of the EESSITaskPayload object."""
+ return f"EESSITaskPayload({self.payload_object.local_file_path}, verified={self.signature_verified})"
diff --git a/scripts/automated_ingestion/eessitarball.py b/scripts/automated_ingestion/eessitarball.py
index 40ac6fa1..cc3c4ae4 100644
--- a/scripts/automated_ingestion/eessitarball.py
+++ b/scripts/automated_ingestion/eessitarball.py
@@ -1,11 +1,9 @@
-from utils import send_slack_message, sha256sum
+from utils import send_slack_message, sha256sum, log_function_entry_exit, log_message, LoggingScope
from pathlib import PurePosixPath
-import boto3
import github
import json
-import logging
import os
import subprocess
import tarfile
@@ -19,7 +17,8 @@ class EessiTarball:
for which it interfaces with the S3 bucket, GitHub, and CVMFS.
"""
- def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo):
+ @log_function_entry_exit()
+ def __init__(self, object_name, config, git_staging_repo, s3_bucket, cvmfs_repo):
"""Initialize the tarball object."""
self.config = config
self.git_repo = git_staging_repo
@@ -27,15 +26,14 @@ def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo
self.metadata_sig_file = self.metadata_file + config['signatures']['signature_file_extension']
self.object = object_name
self.object_sig = object_name + config['signatures']['signature_file_extension']
- self.s3 = s3
- self.bucket = bucket
+ self.s3_bucket = s3_bucket
self.cvmfs_repo = cvmfs_repo
self.local_path = os.path.join(config['paths']['download_dir'], os.path.basename(object_name))
self.local_sig_path = self.local_path + config['signatures']['signature_file_extension']
self.local_metadata_path = self.local_path + config['paths']['metadata_file_extension']
self.local_metadata_sig_path = self.local_metadata_path + config['signatures']['signature_file_extension']
self.sig_verified = False
- self.url = f'https://{bucket}.s3.amazonaws.com/{object_name}'
+ self.url = f'https://{s3_bucket.bucket}.s3.amazonaws.com/{object_name}'
self.states = {
'new': {'handler': self.mark_new_tarball_as_staged, 'next_state': 'staged'},
@@ -49,6 +47,7 @@ def __init__(self, object_name, config, git_staging_repo, s3, bucket, cvmfs_repo
# Find the initial state of this tarball.
self.state = self.find_state()
+ @log_function_entry_exit()
def download(self, force=False):
"""
Download this tarball and its corresponding metadata file, if this hasn't been already done.
@@ -57,32 +56,41 @@ def download(self, force=False):
(self.object, self.local_path, self.object_sig, self.local_sig_path),
(self.metadata_file, self.local_metadata_path, self.metadata_sig_file, self.local_metadata_sig_path),
]
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading %s", files)
skip = False
for (object, local_file, sig_object, local_sig_file) in files:
if force or not os.path.exists(local_file):
# First we try to download signature file, which may or may not be available
# and may be optional or required.
try:
- self.s3.download_file(self.bucket, sig_object, local_sig_file)
- except:
+ log_msg = "Downloading signature file %s to %s"
+ log_message(LoggingScope.DOWNLOAD, 'INFO', log_msg, sig_object, local_sig_file)
+ self.s3_bucket.download_file(self.s3_bucket.bucket, sig_object, local_sig_file)
+ except Exception as err:
+ log_msg = 'Failed to download signature file %s for %s from %s to %s.'
if self.config['signatures'].getboolean('signatures_required', True):
- logging.error(
- f'Failed to download signature file {sig_object} for {object} from {self.bucket} to {local_sig_file}.'
+ log_msg += '\nException: %s'
+ log_message(
+ LoggingScope.ERROR, 'ERROR', log_msg,
+ sig_object, object, self.s3_bucket.bucket, local_sig_file, err
)
skip = True
break
else:
- logging.warning(
- f'Failed to download signature file {sig_object} for {object} from {self.bucket} to {local_sig_file}. ' +
- 'Ignoring this, because signatures are not required with the current configuration.'
+ log_msg += ' Ignoring this, because signatures are not required'
+ log_msg += ' with the current configuration.'
+ log_msg += '\nException: %s'
+ log_message(
+ LoggingScope.DOWNLOAD, 'WARNING', log_msg,
+ sig_object, object, self.s3_bucket.bucket, local_sig_file, err
)
# Now we download the file itself.
try:
- self.s3.download_file(self.bucket, object, local_file)
- except:
- logging.error(
- f'Failed to download {object} from {self.bucket} to {local_file}.'
- )
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading file %s to %s", object, local_file)
+ self.s3_bucket.download_file(self.s3_bucket.bucket, object, local_file)
+ except Exception as err:
+ log_msg = 'Failed to download %s from %s to %s.\nException: %s'
+ log_message(LoggingScope.ERROR, 'ERROR', log_msg, object, self.s3_bucket.bucket, local_file, err)
skip = True
break
# If any required download failed, make sure to skip this tarball completely.
@@ -90,27 +98,30 @@ def download(self, force=False):
self.local_path = None
self.local_metadata_path = None
+ @log_function_entry_exit()
def find_state(self):
"""Find the state of this tarball by searching through the state directories in the git repository."""
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Find state for %s", self.object)
for state in list(self.states.keys()):
- # iterate through the state dirs and try to find the tarball's metadata file
try:
self.git_repo.get_contents(state + '/' + self.metadata_file)
+ log_msg = "Found metadata file %s in state: %s"
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.metadata_file, state)
return state
except github.UnknownObjectException:
# no metadata file found in this state's directory, so keep searching...
continue
- except github.GithubException as e:
- if e.status == 404:
+ except github.GithubException as err:
+ if err.status == 404:
# no metadata file found in this state's directory, so keep searching...
continue
else:
# if there was some other (e.g. connection) issue, abort the search for this tarball
- logging.warning(f'Unable to determine the state of {self.object}, the GitHub API returned status {e.status}!')
+ log_msg = 'Unable to determine the state of %s, the GitHub API returned status %s!'
+ log_message(LoggingScope.ERROR, 'WARNING', log_msg, self.object, err.status)
return "unknown"
- else:
- # if no state was found, we assume this is a new tarball that was ingested to the bucket
- return "new"
+ log_message(LoggingScope.STATE_OPS, 'INFO', "Tarball %s is new", self.metadata_file)
+ return "new"
def get_contents_overview(self):
"""Return an overview of what is included in the tarball."""
@@ -128,7 +139,9 @@ def get_contents_overview(self):
# determine prefix after filtering out '/init' subdirectory,
# to get actual prefix for specific CPU target (like '2023.06/software/linux/aarch64/neoverse_v1')
init_subdir = os.path.join('*', 'init')
- non_init_paths = sorted([p for p in paths if not any(x.match(init_subdir) for x in PurePosixPath(p).parents)])
+ non_init_paths = sorted(
+ [p for p in paths if not any(x.match(init_subdir) for x in PurePosixPath(p).parents)]
+ )
if non_init_paths:
prefix = os.path.commonprefix(non_init_paths)
else:
@@ -148,8 +161,8 @@ def get_contents_overview(self):
other = [ # anything that is not in <prefix>/software nor <prefix>/modules
m.path
for m in members
- if not PurePosixPath(prefix).joinpath('software') in PurePosixPath(m.path).parents
- and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(m.path).parents
+ if (not PurePosixPath(prefix).joinpath('software') in PurePosixPath(m.path).parents
+ and not PurePosixPath(prefix).joinpath('modules') in PurePosixPath(m.path).parents)
# if not fnmatch.fnmatch(m.path, os.path.join(prefix, 'software', '*'))
# and not fnmatch.fnmatch(m.path, os.path.join(prefix, 'modules', '*'))
]
@@ -181,88 +194,121 @@ def run_handler(self):
handler = self.states[self.state]['handler']
handler()
+ def to_string(self, oneline=False):
+ """Serialize tarball info so it can be printed."""
+ str = f"tarball: {self.object}"
+ sep = "\n" if not oneline else ","
+ str += f"{sep} metadt: {self.metadata_file}"
+ str += f"{sep} bucket: {self.s3_bucket.bucket}"
+ str += f"{sep} cvmfs.: {self.cvmfs_repo}"
+ str += f"{sep} GHrepo: {self.git_repo}"
+ return str
+
+ @log_function_entry_exit()
def verify_signatures(self):
- """Verify the signatures of the downloaded tarball and metadata file using the corresponding signature files."""
-
+ """
+ Verify the signatures of the downloaded tarball and metadata file
+ using the corresponding signature files.
+ """
sig_missing_msg = 'Signature file %s is missing.'
sig_missing = False
for sig_file in [self.local_sig_path, self.local_metadata_sig_path]:
if not os.path.exists(sig_file):
- logging.warning(sig_missing_msg % sig_file)
+ log_message(LoggingScope.VERIFICATION, 'WARNING', sig_missing_msg, sig_file)
sig_missing = True
+ log_message(LoggingScope.VERIFICATION, 'INFO', "Signature file %s is missing.", sig_file)
if sig_missing:
# If signature files are missing, we return a failure,
# unless the configuration specifies that signatures are not required.
if self.config['signatures'].getboolean('signatures_required', True):
+ log_message(LoggingScope.ERROR, 'ERROR', "Signature file %s is missing.", sig_file)
return False
else:
+ log_msg = "Signature file %s is missing, but signatures are not required."
+ log_message(LoggingScope.VERIFICATION, 'INFO', log_msg, sig_file)
return True
# If signatures are provided, we should always verify them, regardless of the signatures_required.
# In order to do so, we need the verification script and an allowed signers file.
+ verify_runenv = self.config['signatures']['signature_verification_runenv'].split()
verify_script = self.config['signatures']['signature_verification_script']
allowed_signers_file = self.config['signatures']['allowed_signers_file']
if not os.path.exists(verify_script):
- logging.error(f'Unable to verify signatures, the specified signature verification script does not exist!')
+ log_msg = 'Unable to verify signatures, the specified signature verification script does not exist!'
+ log_message(LoggingScope.ERROR, 'ERROR', log_msg)
return False
if not os.path.exists(allowed_signers_file):
- logging.error(f'Unable to verify signatures, the specified allowed signers file does not exist!')
+ log_msg = 'Unable to verify signatures, the specified allowed signers file does not exist!'
+ log_message(LoggingScope.ERROR, 'ERROR', log_msg)
return False
- for (file, sig_file) in [(self.local_path, self.local_sig_path), (self.local_metadata_path, self.local_metadata_sig_path)]:
+ for (file, sig_file) in [
+ (self.local_path, self.local_sig_path),
+ (self.local_metadata_path, self.local_metadata_sig_path)
+ ]:
+ command = verify_runenv + [verify_script, '--verify', '--allowed-signers-file', allowed_signers_file,
+ '--file', file, '--signature-file', sig_file]
+ log_message(LoggingScope.VERIFICATION, 'INFO', "Running command: %s", ' '.join(command))
+
verify_cmd = subprocess.run(
- [verify_script, '--verify', '--allowed-signers-file', allowed_signers_file, '--file', file, '--signature-file', sig_file],
+ command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
if verify_cmd.returncode == 0:
- logging.debug(f'Signature for {file} successfully verified.')
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Signature for %s successfully verified.', file)
else:
- logging.error(f'Failed to verify signature for {file}.')
+ log_message(LoggingScope.ERROR, 'ERROR', 'Failed to verify signature for %s.', file)
+ log_message(LoggingScope.ERROR, 'ERROR', " stdout: %s", verify_cmd.stdout.decode('UTF-8'))
+ log_message(LoggingScope.ERROR, 'ERROR', " stderr: %s", verify_cmd.stderr.decode('UTF-8'))
return False
self.sig_verified = True
return True
+ @log_function_entry_exit()
def verify_checksum(self):
"""Verify the checksum of the downloaded tarball with the one in its metadata file."""
local_sha256 = sha256sum(self.local_path)
meta_sha256 = None
with open(self.local_metadata_path, 'r') as meta:
meta_sha256 = json.load(meta)['payload']['sha256sum']
- logging.debug(f'Checksum of downloaded tarball: {local_sha256}')
- logging.debug(f'Checksum stored in metadata file: {meta_sha256}')
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Checksum of downloaded tarball: %s', local_sha256)
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', 'Checksum stored in metadata file: %s', meta_sha256)
return local_sha256 == meta_sha256
+ @log_function_entry_exit()
def ingest(self):
"""Process a tarball that is ready to be ingested by running the ingestion script."""
- #TODO: check if there is an open issue for this tarball, and if there is, skip it.
- logging.info(f'Tarball {self.object} is ready to be ingested.')
+ # TODO: check if there is an open issue for this tarball, and if there is, skip it.
+ log_message(LoggingScope.STATE_OPS, 'INFO', 'Tarball %s is ready to be ingested.', self.object)
self.download()
- logging.info('Verifying its signature...')
+ log_message(LoggingScope.VERIFICATION, 'INFO', 'Verifying its signature...')
if not self.verify_signatures():
issue_msg = f'Failed to verify signatures for `{self.object}`'
- logging.error(issue_msg)
+ log_message(LoggingScope.ERROR, 'ERROR', issue_msg)
if not self.issue_exists(issue_msg, state='open'):
self.git_repo.create_issue(title=issue_msg, body=issue_msg)
return
else:
- logging.debug(f'Signatures of {self.object} and its metadata file successfully verified.')
+ log_msg = 'Signatures of %s and its metadata file successfully verified.'
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', log_msg, self.object)
- logging.info('Verifying its checksum...')
+ log_message(LoggingScope.VERIFICATION, 'INFO', 'Verifying its checksum...')
if not self.verify_checksum():
issue_msg = f'Failed to verify checksum for `{self.object}`'
- logging.error(issue_msg)
+ log_message(LoggingScope.ERROR, 'ERROR', issue_msg)
if not self.issue_exists(issue_msg, state='open'):
self.git_repo.create_issue(title=issue_msg, body=issue_msg)
return
else:
- logging.debug(f'Checksum of {self.object} matches the one in its metadata file.')
+ log_msg = 'Checksum of %s matches the one in its metadata file.'
+ log_message(LoggingScope.VERIFICATION, 'DEBUG', log_msg, self.object)
script = self.config['paths']['ingestion_script']
sudo = ['sudo'] if self.config['cvmfs'].getboolean('ingest_as_root', True) else []
- logging.info(f'Running the ingestion script for {self.object}...')
+ log_message(LoggingScope.STATE_OPS, 'INFO', 'Running the ingestion script for %s...', self.object)
ingest_cmd = subprocess.run(
sudo + [script, self.cvmfs_repo, self.local_path],
stdout=subprocess.PIPE,
@@ -273,7 +319,9 @@ def ingest(self):
if self.config.has_section('slack') and self.config['slack'].getboolean('ingestion_notification', False):
send_slack_message(
self.config['secrets']['slack_webhook'],
- self.config['slack']['ingestion_message'].format(tarball=os.path.basename(self.object), cvmfs_repo=self.cvmfs_repo)
+ self.config['slack']['ingestion_message'].format(
+ tarball=os.path.basename(self.object),
+ cvmfs_repo=self.cvmfs_repo)
)
else:
issue_title = f'Failed to ingest {self.object}'
@@ -285,133 +333,262 @@ def ingest(self):
stderr=ingest_cmd.stderr.decode('UTF-8'),
)
if self.issue_exists(issue_title, state='open'):
- logging.info(f'Failed to ingest {self.object}, but an open issue already exists, skipping...')
+ log_msg = 'Failed to ingest %s, but an open issue already exists, skipping...'
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.object)
else:
self.git_repo.create_issue(title=issue_title, body=issue_body)
def print_ingested(self):
"""Process a tarball that has already been ingested."""
- logging.info(f'{self.object} has already been ingested, skipping...')
+ log_message(LoggingScope.STATE_OPS, 'INFO', '%s has already been ingested, skipping...', self.object)
- def mark_new_tarball_as_staged(self):
+ @log_function_entry_exit()
+ def mark_new_tarball_as_staged(self, branch=None):
"""Process a new tarball that was added to the staging bucket."""
next_state = self.next_state(self.state)
- logging.info(f'Found new tarball {self.object}, downloading it...')
+ log_msg = 'Found new tarball %s, downloading it...'
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, self.object)
# Download the tarball and its metadata file.
# Use force as it may be a new attempt for an existing tarball that failed before.
self.download(force=True)
if not self.local_path or not self.local_metadata_path:
- logging.warn('Skipping this tarball...')
+ log_msg = "Skipping tarball %s - download failed"
+ log_message(LoggingScope.STATE_OPS, 'WARNING', log_msg, self.object)
return
# Verify the signatures of the tarball and metadata file.
if not self.verify_signatures():
- logging.warn('Signature verification of the tarball or its metadata failed, skipping this tarball...')
+ log_msg = "Skipping tarball %s - signature verification failed"
+ log_message(LoggingScope.STATE_OPS, 'WARNING', log_msg, self.object)
+ return
+ # If no branch is provided, use the main branch
+ target_branch = branch if branch else 'main'
+ log_msg = "Adding metadata to '%s' folder in %s branch"
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg, next_state, target_branch)
+
+ file_path_staged = next_state + '/' + self.metadata_file
contents = ''
with open(self.local_metadata_path, 'r') as meta:
contents = meta.read()
-
- logging.info(f'Adding tarball\'s metadata to the "{next_state}" folder of the git repository.')
- file_path_staged = next_state + '/' + self.metadata_file
- new_file = self.git_repo.create_file(file_path_staged, 'new tarball', contents, branch='main')
+ self.git_repo.create_file(file_path_staged, 'new tarball', contents, branch=target_branch)
self.state = next_state
- self.run_handler()
+ if not branch: # Only run handler if we're not part of a group
+ self.run_handler()
def print_rejected(self):
"""Process a (rejected) tarball for which the corresponding PR has been closed witout merging."""
- logging.info("This tarball was rejected, so we're skipping it.")
+ log_message(LoggingScope.STATE_OPS, 'INFO', "This tarball was rejected, so we're skipping it.")
# Do we want to delete rejected tarballs at some point?
def print_unknown(self):
"""Process a tarball which has an unknown state."""
- logging.info("The state of this tarball could not be determined, so we're skipping it.")
+ log_msg = "The state of this tarball could not be determined,"
+ log_msg += " so we're skipping it."
+ log_message(LoggingScope.STATE_OPS, 'INFO', log_msg)
+
+ def find_next_sequence_number(self, repo, pr_id):
+ """Find the next available sequence number for staging PRs of a source PR."""
+ # Search for existing branches for this source PR
+ base_branch = f'staging-{repo.replace("/", "-")}-pr-{pr_id}-seq-'
+ existing_branches = [
+ ref.ref for ref in self.git_repo.get_git_refs()
+ if ref.ref.startswith(f'refs/heads/{base_branch}')
+ ]
+
+ if not existing_branches:
+ return 1
+
+ # Extract sequence numbers from existing branches
+ sequence_numbers = []
+ for branch in existing_branches:
+ try:
+ # Extract the sequence number from branch name
+ # Format: staging-<repo>-pr-<pr_id>-seq-<sequence>
+ sequence = int(branch.split('-')[-1])
+ sequence_numbers.append(sequence)
+ except (ValueError, IndexError):
+ continue
+
+ if not sequence_numbers:
+ return 1
- def make_approval_request(self):
+ # Return next available sequence number
+ return max(sequence_numbers) + 1
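+
+ # Illustrative (hypothetical) example: if branches staging-EESSI-software-layer-pr-42-seq-1 and
+ # staging-EESSI-software-layer-pr-42-seq-2 already exist for source PR EESSI/software-layer#42,
+ # this method returns 3.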
+
+ @log_function_entry_exit()
+ def make_approval_request(self, tarballs_in_group=None):
"""Process a staged tarball by opening a pull request for ingestion approval."""
next_state = self.next_state(self.state)
- file_path_staged = self.state + '/' + self.metadata_file
- file_path_to_ingest = next_state + '/' + self.metadata_file
+ log_msg = "Making approval request for tarball %s in state %s to %s"
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object, self.state, next_state)
+ # obtain link2pr information (repo and pr_id) from metadata file
+ with open(self.local_metadata_path, 'r') as meta:
+ metadata = meta.read()
+ meta_dict = json.loads(metadata)
+ repo, pr_id = meta_dict['link2pr']['repo'], meta_dict['link2pr']['pr']
- filename = os.path.basename(self.object)
- tarball_metadata = self.git_repo.get_contents(file_path_staged)
- git_branch = filename + '_' + next_state
- self.download()
+ # find next sequence number for staging PRs of this source PR
+ sequence = self.find_next_sequence_number(repo, pr_id)
+ git_branch = f'staging-{repo.replace("/", "-")}-pr-{pr_id}-seq-{sequence}'
+ # Check if git_branch exists and what the status of the corresponding PR is
main_branch = self.git_repo.get_branch('main')
if git_branch in [branch.name for branch in self.git_repo.get_branches()]:
- # Existing branch found for this tarball, so we've run this step before.
- # Try to find out if there's already a PR as well...
- logging.info("Branch already exists for " + self.object)
- # Filtering with only head= returns all prs if there's no match, so double-check
- find_pr = [pr for pr in self.git_repo.get_pulls(head=git_branch, state='all') if pr.head.ref == git_branch]
- logging.debug('Found PRs: ' + str(find_pr))
+ log_msg = "Branch %s already exists, checking the status of the corresponding PR..."
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, git_branch)
+ find_pr = [pr for pr in self.git_repo.get_pulls(head=git_branch, state='all')
+ if pr.head.ref == git_branch]
if find_pr:
- # So, we have a branch and a PR for this tarball (if there are more, pick the first one)...
pr = find_pr.pop(0)
- logging.info(f'PR {pr.number} found for {self.object}')
if pr.state == 'open':
- # The PR is still open, so it hasn't been reviewed yet: ignore this tarball.
- logging.info('PR is still open, skipping this tarball...')
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', 'PR is still open, skipping this tarball...')
return
elif pr.state == 'closed' and not pr.merged:
- # The PR was closed but not merged, i.e. it was rejected for ingestion.
- logging.info('PR was rejected')
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', 'PR was rejected')
self.reject()
return
else:
- logging.warn(f'Warning, tarball {self.object} is in a weird state:')
- logging.warn(f'Branch: {git_branch}\nPR: {pr}\nPR state: {pr.state}\nPR merged: {pr.merged}')
+ log_msg = 'Warning, tarball %s is in a weird state:'
+ log_message(LoggingScope.GITHUB_OPS, 'WARNING', log_msg, self.object)
+ log_msg = 'Branch: %s\nPR: %s\nPR state: %s\nPR merged: %s'
+ log_message(LoggingScope.GITHUB_OPS, 'WARNING', log_msg,
+ git_branch, pr, pr.state, pr.merged)
+ # TODO: should we delete the branch or open an issue?
+ return
else:
- # There is a branch, but no PR for this tarball.
- # This is weird, so let's remove the branch and reprocess the tarball.
- logging.info(f'Tarball {self.object} has a branch, but no PR.')
- logging.info(f'Removing existing branch...')
+ log_msg = 'Tarball %s has a branch, but no PR.'
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object)
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', 'Removing existing branch...')
ref = self.git_repo.get_git_ref(f'heads/{git_branch}')
ref.delete()
- logging.info(f'Making pull request to get ingestion approval for {self.object}.')
- # Create a new branch
+
+ # Create new branch
self.git_repo.create_git_ref(ref='refs/heads/' + git_branch, sha=main_branch.commit.sha)
- # Move the file to the directory of the next stage in this branch
- self.move_metadata_file(self.state, next_state, branch=git_branch)
- # Get metadata file contents
- metadata = ''
- with open(self.local_metadata_path, 'r') as meta:
- metadata = meta.read()
- meta_dict = json.loads(metadata)
- repo, pr_id = meta_dict['link2pr']['repo'], meta_dict['link2pr']['pr']
- pr_url = f"https://github.com/{repo}/pull/{pr_id}"
- # Try to get the tarball contents and open a PR to get approval for the ingestion
+
+ # Move metadata file(s) to approved directory
+ log_msg = "Moving metadata for %s from %s to %s in branch %s"
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg,
+ self.object, self.state, next_state, git_branch)
+ if tarballs_in_group is None:
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', "Moving metadata for individual tarball to staged")
+ self.move_metadata_file(self.state, next_state, branch=git_branch)
+ else:
+ log_msg = "Moving metadata for %d tarballs to staged"
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, len(tarballs_in_group))
+ for tarball in tarballs_in_group:
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ temp_tar.move_metadata_file(self.state, next_state, branch=git_branch)
+
+ # Create PR with appropriate template
try:
- tarball_contents = self.get_contents_overview()
- pr_body = self.config['github']['pr_body'].format(
- cvmfs_repo=self.cvmfs_repo,
- pr_url=pr_url,
- tar_overview=self.get_contents_overview(),
- metadata=metadata,
- )
- pr_title = '[%s] Ingest %s' % (self.cvmfs_repo, filename)
+ pr_url = f"https://github.com/{repo}/pull/{pr_id}"
+ if tarballs_in_group is None:
+ log_msg = "Creating PR for individual tarball: %s"
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', log_msg, self.object)
+ tarball_contents = self.get_contents_overview()
+ pr_body = self.config['github']['individual_pr_body'].format(
+ cvmfs_repo=self.cvmfs_repo,
+ pr_url=pr_url,
+ tar_overview=tarball_contents,
+ metadata=metadata,
+ )
+ pr_title = f'[{self.cvmfs_repo}] Ingest {os.path.basename(self.object)}'
+ else:
+ # Group of tarballs
+ tar_overviews = []
+ for tarball in tarballs_in_group:
+ try:
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ temp_tar.download()
+ overview = temp_tar.get_contents_overview()
+ tar_details_tpl = "\nContents of %s
\n\n%s\n \n"
+ tar_overviews.append(tar_details_tpl % (tarball, overview))
+ except Exception as err:
+ log_msg = "Failed to get contents overview for %s: %s"
+ log_message(LoggingScope.ERROR, 'ERROR', log_msg, tarball, err)
+ tar_details_tpl = "\nContents of %s
\n\n"
+ tar_details_tpl += "Failed to get contents overview: %s\n \n"
+ tar_overviews.append(tar_details_tpl % (tarball, err))
+
+ pr_body = self.config['github']['grouped_pr_body'].format(
+ cvmfs_repo=self.cvmfs_repo,
+ pr_url=pr_url,
+ tarballs=self.format_tarball_list(tarballs_in_group),
+ metadata=self.format_metadata_list(tarballs_in_group),
+ tar_overview="\n".join(tar_overviews)
+ )
+ pr_title = f'[{self.cvmfs_repo}] Staging PR #{sequence} for {repo}#{pr_id}'
+
+ # Add signature verification status if applicable
if self.sig_verified:
- pr_body += "\n\n:heavy_check_mark: :closed_lock_with_key: The signature of this tarball has been successfully verified."
+ pr_body += "\n\n:heavy_check_mark: :closed_lock_with_key: "
+ pr_body += "The signature of this tarball has been successfully verified."
pr_title += ' :closed_lock_with_key:'
+
self.git_repo.create_pull(title=pr_title, body=pr_body, head=git_branch, base='main')
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', "Created PR: %s", pr_title)
+
except Exception as err:
- issue_title = f'Failed to get contents of {self.object}'
- issue_body = self.config['github']['failed_tarball_overview_issue_body'].format(
- tarball=self.object,
- error=err
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to create PR: %s", err)
+ if not self.issue_exists(f'Failed to get contents of {self.object}', state='open'):
+ self.git_repo.create_issue(
+ title=f'Failed to get contents of {self.object}',
+ body=self.config['github']['failed_tarball_overview_issue_body'].format(
+ tarball=self.object,
+ error=err
+ )
+ )
+
+ def format_tarball_list(self, tarballs):
+ """Format a list of tarballs with checkboxes for approval."""
+ formatted = "### Tarballs to be ingested\n\n"
+ for tarball in tarballs:
+ formatted += f"- [ ] {tarball}\n"
+ return formatted
+
+ def format_metadata_list(self, tarballs):
+ """Format metadata for all tarballs in collapsible sections."""
+ formatted = "### Metadata\n\n"
+ for tarball in tarballs:
+ with open(self.get_metadata_path(tarball), 'r') as meta:
+ metadata = meta.read()
+ formatted += (
+ f"\nMetadata for {tarball}
\n\n"
+ f"```\n{metadata}\n```\n \n\n"
+ )
+ return formatted
+
+ def get_metadata_path(self, tarball=None):
+ """
+ Return the local path of the metadata file.
+
+ Args:
+ tarball (str, optional): Name of the tarball to get metadata path for.
+ If None, use the current tarball's metadata file.
+ """
+ if tarball is None:
+ # For single tarball, use the instance's metadata file
+ if not self.local_metadata_path:
+ self.local_metadata_path = os.path.join(
+ self.config['paths']['download_dir'],
+ os.path.basename(self.metadata_file)
+ )
+ return self.local_metadata_path
+ else:
+ # For group of tarballs, construct path from tarball name
+ return os.path.join(
+ self.config['paths']['download_dir'],
+ os.path.basename(tarball) + self.config['paths']['metadata_file_extension']
)
- if len([i for i in self.git_repo.get_issues(state='open') if i.title == issue_title]) == 0:
- self.git_repo.create_issue(title=issue_title, body=issue_body)
- else:
- logging.info(f'Failed to create tarball overview, but an issue already exists.')
def move_metadata_file(self, old_state, new_state, branch='main'):
"""Move the metadata file of a tarball from an old state's directory to a new state's directory."""
file_path_old = old_state + '/' + self.metadata_file
file_path_new = new_state + '/' + self.metadata_file
- logging.debug(f'Moving metadata file {self.metadata_file} from {file_path_old} to {file_path_new}.')
+ log_message(LoggingScope.GITHUB_OPS, 'INFO', 'Moving metadata file %s from %s to %s in branch %s',
+ self.metadata_file, file_path_old, file_path_new, branch)
tarball_metadata = self.git_repo.get_contents(file_path_old)
# Remove the metadata file from the old state's directory...
self.git_repo.delete_file(file_path_old, 'remove from ' + old_state, sha=tarball_metadata.sha, branch=branch)
@@ -419,10 +596,54 @@ def move_metadata_file(self, old_state, new_state, branch='main'):
self.git_repo.create_file(file_path_new, 'move to ' + new_state, tarball_metadata.decoded_content,
branch=branch)
+ def process_pr_merge(self, pr_number):
+ """Process a merged PR by handling the checkboxes and moving tarballs to appropriate states."""
+ pr = self.git_repo.get_pull(pr_number)
+
+ # Get the branch name
+ branch_name = pr.head.ref
+
+ # Get the list of tarballs from the PR body
+ tarballs = self.extract_tarballs_from_pr_body(pr.body)
+
+ # Get the checked status for each tarball
+ checked_tarballs = self.extract_checked_tarballs(pr.body)
+
+ # Process each tarball
+ for tarball in tarballs:
+ # Use a tarball-specific instance so that each tarball's own metadata file is moved
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ if tarball in checked_tarballs:
+ # Move to approved state
+ temp_tar.move_metadata_file('staged', 'approved', branch=branch_name)
+ else:
+ # Move to rejected state
+ temp_tar.move_metadata_file('staged', 'rejected', branch=branch_name)
+
+ # Delete the branch after processing
+ ref = self.git_repo.get_git_ref(f'heads/{branch_name}')
+ ref.delete()
+
+ def extract_checked_tarballs(self, pr_body):
+ """Extract list of checked tarballs from PR body."""
+ checked_tarballs = []
+ for line in pr_body.split('\n'):
+ if line.strip().startswith('- [x] '):
+ tarball = line.strip()[6:] # Remove '- [x] ' prefix
+ checked_tarballs.append(tarball)
+ return checked_tarballs
+
+ def extract_tarballs_from_pr_body(self, pr_body):
+ """Extract list of all tarballs from PR body."""
+ tarballs = []
+ for line in pr_body.split('\n'):
+ if line.strip().startswith('- ['):
+ tarball = line.strip()[6:] # Remove '- [ ] ' or '- [x] ' prefix
+ tarballs.append(tarball)
+ return tarballs
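+
+ # Illustrative sketch of the PR body lines the two helpers above parse (tarball names are
+ # hypothetical; the checkbox list itself is generated by format_tarball_list):
+ # - [x] tarball-a.tar.gz -> returned by both helpers (approved on merge)
+ # - [ ] tarball-b.tar.gz -> returned only by extract_tarballs_from_pr_body (rejected on merge)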
+
def reject(self):
"""Reject a tarball for ingestion."""
# Let's move the tarball's metadata file to the directory for rejected tarballs.
- logging.info(f'Marking tarball {self.object} as rejected...')
+ log_message(LoggingScope.STATE_OPS, 'INFO', 'Marking tarball %s as rejected...', self.object)
next_state = 'rejected'
self.move_metadata_file(self.state, next_state)
@@ -434,3 +655,82 @@ def issue_exists(self, title, state='open'):
return True
else:
return False
+
+ def get_link2pr_info(self):
+ """Get the link2pr information from the metadata file."""
+ with open(self.local_metadata_path, 'r') as meta:
+ metadata = json.load(meta)
+ return metadata['link2pr']['repo'], metadata['link2pr']['pr']
+
+
+class EessiTarballGroup:
+ """Class to handle a group of tarballs that share the same link2pr information."""
+
+ def __init__(self, first_tarball, config, git_staging_repo, s3_bucket, cvmfs_repo):
+ """Initialize with the first tarball in the group."""
+ self.first_tar = EessiTarball(first_tarball, config, git_staging_repo, s3_bucket, cvmfs_repo)
+ self.config = config
+ self.git_repo = git_staging_repo
+ self.s3_bucket = s3_bucket
+ self.cvmfs_repo = cvmfs_repo
+
+ def download_tarballs_and_more(self, tarballs):
+ """Download all files associated with this group of tarballs."""
+ for tarball in tarballs:
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "downloading files for '%s'", temp_tar.object)
+ temp_tar.download(force=True)
+ if not temp_tar.local_path or not temp_tar.local_metadata_path:
+ log_message(LoggingScope.GROUP_OPS, 'WARNING', "Skipping this tarball: %s", temp_tar.object)
+ return False
+ return True
+
+ def process_group(self, tarballs):
+ """Process a group of tarballs together."""
+ log_message(LoggingScope.GROUP_OPS, 'INFO', "Processing group of %d tarballs", len(tarballs))
+
+ if not self.download_tarballs_and_more(tarballs):
+ log_msg = "Downloading tarballs, metadata files and/or their signatures failed"
+ log_message(LoggingScope.ERROR, 'ERROR', log_msg)
+ return
+
+ # Verify all tarballs have the same link2pr info
+ if not self.verify_group_consistency(tarballs):
+ log_message(LoggingScope.ERROR, 'ERROR', "Tarballs have inconsistent link2pr information")
+ return
+
+ # Mark all tarballs as staged; pass 'main' explicitly so run_handler is skipped, and reuse the existing instance for the first tarball
+ log_msg = "Processing first tarball in group: %s"
+ log_message(LoggingScope.GROUP_OPS, 'INFO', log_msg, self.first_tar.object)
+ self.first_tar.mark_new_tarball_as_staged('main') # this sets the state of the first tarball to 'staged'
+ for tarball in tarballs[1:]:
+ log_msg = "Processing tarball in group: %s"
+ log_message(LoggingScope.GROUP_OPS, 'INFO', log_msg, tarball)
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ temp_tar.mark_new_tarball_as_staged('main')
+
+ # Process the group for approval, only works correctly if first tarball is already in state 'staged'
+ self.first_tar.make_approval_request(tarballs)
+
+ def to_string(self, oneline=False):
+ """Serialize tarball group info so it can be printed."""
+ str = f"first tarball: {self.first_tar.to_string(oneline)}"
+ sep = "\n" if not oneline else ","
+ str += f"{sep} config: {self.config}"
+ str += f"{sep} GHrepo: {self.git_repo}"
+ str += f"{sep} s3....: {self.s3_bucket}"
+ str += f"{sep} bucket: {self.s3_bucket.bucket}"
+ str += f"{sep} cvmfs.: {self.cvmfs_repo}"
+ return str
+
+ def verify_group_consistency(self, tarballs):
+ """Verify all tarballs in the group have the same link2pr information."""
+ first_repo, first_pr = self.first_tar.get_link2pr_info()
+
+ for tarball in tarballs[1:]: # Skip first tarball as we already have its info
+ temp_tar = EessiTarball(tarball, self.config, self.git_repo, self.s3_bucket, self.cvmfs_repo)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "temp tar: %s", temp_tar.to_string())
+ repo, pr = temp_tar.get_link2pr_info()
+ if repo != first_repo or pr != first_pr:
+ return False
+ return True
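+
+
+# Minimal sketch (not part of this module) of how a driver script might use the group class,
+# assuming `tarballs` is a list of S3 object names sharing the same link2pr information:
+#
+#   group = EessiTarballGroup(tarballs[0], config, git_staging_repo, s3_bucket, cvmfs_repo)
+#   group.process_group(tarballs)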
diff --git a/scripts/automated_ingestion/remote_storage.py b/scripts/automated_ingestion/remote_storage.py
new file mode 100644
index 00000000..2a386a7d
--- /dev/null
+++ b/scripts/automated_ingestion/remote_storage.py
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Protocol, runtime_checkable
+
+
+class DownloadMode(Enum):
+ """Enum defining different modes for downloading files."""
+ FORCE = 'force' # Always download and overwrite
+ CHECK_REMOTE = 'check-remote' # Download if remote files have changed
+ CHECK_LOCAL = 'check-local' # Download if files don't exist locally (default)
+
+
+@runtime_checkable
+class RemoteStorageClient(Protocol):
+ """Protocol defining the interface for remote storage clients."""
+
+ def get_metadata(self, remote_path: str) -> dict:
+ """Get metadata about a remote object.
+
+ Args:
+ remote_path: Path to the object in remote storage
+
+ Returns:
+ Dictionary containing object metadata, including 'ETag' key
+ """
+ ...
+
+ def download(self, remote_path: str, local_path: str) -> None:
+ """Download a remote file to a local location.
+
+ Args:
+ remote_path: Path to the object in remote storage
+ local_path: Local path where to save the file
+ """
+ ...
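+
+
+# Minimal sketch of a class satisfying this protocol; the class and values below are purely
+# illustrative. Because RemoteStorageClient is runtime_checkable, isinstance() only checks
+# that the methods exist:
+#
+#   class DummyStorage:
+#       def get_metadata(self, remote_path: str) -> dict:
+#           return {'ETag': '"d41d8cd98f00b204e9800998ecf8427e"'}
+#
+#       def download(self, remote_path: str, local_path: str) -> None:
+#           open(local_path, 'w').close()
+#
+#   assert isinstance(DummyStorage(), RemoteStorageClient)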
diff --git a/scripts/automated_ingestion/s3_bucket.py b/scripts/automated_ingestion/s3_bucket.py
new file mode 100644
index 00000000..79fed289
--- /dev/null
+++ b/scripts/automated_ingestion/s3_bucket.py
@@ -0,0 +1,187 @@
+import os
+from pathlib import Path
+from typing import Dict, Optional
+
+import boto3
+from botocore.exceptions import ClientError
+from utils import log_function_entry_exit, log_message, LoggingScope
+from remote_storage import RemoteStorageClient
+
+
+class EESSIS3Bucket(RemoteStorageClient):
+ """EESSI-specific S3 bucket implementation of the RemoteStorageClient protocol."""
+
+ @log_function_entry_exit()
+ def __init__(self, config, bucket_name: str):
+ """
+ Initialize the EESSI S3 bucket.
+
+ Args:
+ config: Configuration object containing:
+ - secrets.aws_access_key_id: AWS access key ID (optional, can use AWS_ACCESS_KEY_ID env var)
+ - secrets.aws_secret_access_key: AWS secret access key (optional, can use AWS_SECRET_ACCESS_KEY env var)
+ - aws.endpoint_url: Custom endpoint URL for S3-compatible backends (optional)
+ - aws.verify: SSL verification setting (optional)
+ - True: Verify SSL certificates (default)
+ - False: Skip SSL certificate verification
+ - str: Path to CA bundle file
+ bucket_name: Name of the S3 bucket to use
+ """
+ self.bucket = bucket_name
+
+ # Get AWS credentials from environment or config
+ aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID') or config.get('secrets', 'aws_access_key_id')
+ aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY') or config.get('secrets', 'aws_secret_access_key')
+
+ # Configure boto3 client
+ client_config = {}
+
+ # Add endpoint URL if specified in config
+ if config.has_option('aws', 'endpoint_url'):
+ client_config['endpoint_url'] = config['aws']['endpoint_url']
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Using custom endpoint URL: %s", client_config['endpoint_url'])
+
+ # Add SSL verification if specified in config
+ if config.has_option('aws', 'verify'):
+ verify = config['aws']['verify']
+ if verify.lower() == 'false':
+ client_config['verify'] = False
+ log_message(LoggingScope.DEBUG, 'WARNING', "SSL verification disabled")
+ elif verify.lower() == 'true':
+ client_config['verify'] = True
+ else:
+ client_config['verify'] = verify # Assume it's a path to CA bundle
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Using custom CA bundle: %s", verify)
+
+ self.client = boto3.client(
+ 's3',
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ **client_config
+ )
+ log_message(LoggingScope.DEBUG, 'INFO', "Initialized S3 client for bucket: %s", self.bucket)
+
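+ # Illustrative config snippet matching the options read above (the endpoint URL and
+ # CA bundle path are hypothetical placeholders):
+ #
+ #   [secrets]
+ #   aws_access_key_id = ...
+ #   aws_secret_access_key = ...
+ #
+ #   [aws]
+ #   endpoint_url = https://minio.example.org:9000
+ #   verify = /etc/ssl/certs/ca-bundle.crt
+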
+ def list_objects_v2(self, **kwargs):
+ """
+ List objects in the bucket using the underlying boto3 client.
+
+ Args:
+ **kwargs: Additional arguments to pass to boto3.client.list_objects_v2
+
+ Returns:
+ Response from boto3.client.list_objects_v2
+ """
+ return self.client.list_objects_v2(Bucket=self.bucket, **kwargs)
+
+ def download_file(self, key: str, filename: str) -> None:
+ """
+ Download a file from S3 to a local file.
+
+ Args:
+ key: The S3 key of the file to download
+ filename: The local path where the file should be saved
+ """
+ self.client.download_file(self.bucket, key, filename)
+
+ @log_function_entry_exit()
+ def get_metadata(self, remote_path: str) -> Dict:
+ """
+ Get metadata about an S3 object.
+
+ Args:
+ remote_path: Path to the object in S3
+
+ Returns:
+ Dictionary containing object metadata, including 'ETag' key
+ """
+ try:
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Getting metadata for S3 object: %s", remote_path)
+ response = self.client.head_object(Bucket=self.bucket, Key=remote_path)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Retrieved metadata for %s: %s", remote_path, response)
+ return response
+ except ClientError as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to get metadata for %s: %s", remote_path, str(err))
+ raise
+
+ def _get_etag_file_path(self, local_path: str) -> Path:
+ """Get the path to the .etag file for a given local file."""
+ return Path(local_path).with_suffix('.etag')
+
+ def _read_etag(self, local_path: str) -> Optional[str]:
+ """Read the ETag from the .etag file if it exists."""
+ etag_path = self._get_etag_file_path(local_path)
+ if etag_path.exists():
+ try:
+ with open(etag_path, 'r') as f:
+ return f.read().strip()
+ except Exception as e:
+ log_message(LoggingScope.DEBUG, 'WARNING', "Failed to read ETag file %s: %s", etag_path, str(e))
+ return None
+ return None
+
+ def _write_etag(self, local_path: str, etag: str) -> None:
+ """Write the ETag to the .etag file."""
+ etag_path = self._get_etag_file_path(local_path)
+ try:
+ with open(etag_path, 'w') as f:
+ f.write(etag)
+ log_message(LoggingScope.DEBUG, 'DEBUG', "Wrote ETag to %s", etag_path)
+ except Exception as e:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to write ETag file %s: %s", etag_path, str(e))
+ # If we can't write the etag file, it's not critical
+ # The file will just be downloaded again next time
+
+ @log_function_entry_exit()
+ def download(self, remote_path: str, local_path: str) -> None:
+ """
+ Download an S3 object to a local location and store its ETag.
+
+ Args:
+ remote_path: Path to the object in S3
+ local_path: Local path where to save the file
+ """
+ try:
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading %s to %s", remote_path, local_path)
+ self.client.download_file(Bucket=self.bucket, Key=remote_path, Filename=local_path)
+ log_message(LoggingScope.DOWNLOAD, 'INFO', "Successfully downloaded %s to %s", remote_path, local_path)
+ except ClientError as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Failed to download %s: %s", remote_path, str(err))
+ raise
+
+ # Get the object's metadata to obtain its ETag
+ metadata = self.get_metadata(remote_path)
+ etag = metadata['ETag']
+
+ # Store the ETag
+ self._write_etag(local_path, etag)
+
+ @log_function_entry_exit()
+ def get_bucket_url(self) -> str:
+ """
+ Get the HTTPS URL for a bucket from an initialized boto3 client.
+ Works with both AWS S3 and MinIO/S3-compatible services.
+ """
+ try:
+ # Check if this is a custom endpoint (MinIO) or AWS S3
+ endpoint_url = self.client.meta.endpoint_url
+
+ if endpoint_url:
+ # Custom endpoint (MinIO, DigitalOcean Spaces, etc.)
+ # Most S3-compatible services use path-style URLs
+ bucket_url = f"{endpoint_url}/{self.bucket}"
+
+ else:
+ # AWS S3 (no custom endpoint specified)
+ region = self.client.meta.region_name or 'us-east-1'
+
+ # AWS S3 virtual-hosted-style URLs
+ if region == 'us-east-1':
+ bucket_url = f"https://{self.bucket}.s3.amazonaws.com"
+ else:
+ bucket_url = f"https://{self.bucket}.s3.{region}.amazonaws.com"
+
+ return bucket_url
+
+ except Exception as err:
+ log_message(LoggingScope.ERROR, 'ERROR', "Error getting bucket URL: %s", str(err))
+ return None
diff --git a/scripts/automated_ingestion/utils.py b/scripts/automated_ingestion/utils.py
index 66843dd9..4b867764 100644
--- a/scripts/automated_ingestion/utils.py
+++ b/scripts/automated_ingestion/utils.py
@@ -1,6 +1,104 @@
import hashlib
import json
import requests
+import logging
+import functools
+import time
+import os
+import inspect
+from enum import IntFlag, auto
+import sys
+
+
+class LoggingScope(IntFlag):
+ """Enumeration of different logging scopes."""
+ NONE = 0
+ FUNC_ENTRY_EXIT = auto() # Function entry/exit logging
+ DOWNLOAD = auto() # Logging related to file downloads
+ VERIFICATION = auto() # Logging related to signature and checksum verification
+ STATE_OPS = auto() # Logging related to tarball state operations
+ GITHUB_OPS = auto() # Logging related to GitHub operations (PRs, issues, etc.)
+ GROUP_OPS = auto() # Logging related to tarball group operations
+ TASK_OPS = auto() # Logging related to task operations
+ ERROR = auto() # Error logging (separate from other scopes for easier filtering)
+ DEBUG = auto() # Debug-level logging (separate from other scopes for easier filtering)
+ ALL = (FUNC_ENTRY_EXIT | DOWNLOAD | VERIFICATION | STATE_OPS |
+ GITHUB_OPS | GROUP_OPS | TASK_OPS | ERROR | DEBUG)
+
+
+# Global setting for logging scopes
+ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+
+
+# Global variable to track call stack depth
+_call_stack_depth = 0
+
+
+def set_logging_scopes(scopes):
+ """
+ Set the enabled logging scopes.
+
+ Args:
+ scopes: Can be:
+ - A LoggingScope value
+ - A string with comma-separated values using +/- syntax:
+ - "+SCOPE" to enable a scope
+ - "-SCOPE" to disable a scope
+ - "ALL" or "+ALL" to enable all scopes
+ - "-ALL" to disable all scopes
+ Examples:
+ "+FUNC_ENTRY_EXIT" # Enable only function entry/exit
+ "+FUNC_ENTRY_EXIT,-EXAMPLE_SCOPE" # Enable function entry/exit but disable example
+ "+ALL,-FUNC_ENTRY_EXIT" # Enable all scopes except function entry/exit
+ """
+ global ENABLED_LOGGING_SCOPES
+
+ if isinstance(scopes, LoggingScope):
+ ENABLED_LOGGING_SCOPES = scopes
+ return
+
+ if isinstance(scopes, str):
+ # Start with no scopes enabled
+ ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+
+ # Split into individual scope specifications
+ scope_specs = [s.strip() for s in scopes.split(",")]
+
+ for spec in scope_specs:
+ if not spec:
+ continue
+
+ # Check for ALL special case
+ if spec.upper() in ["ALL", "+ALL"]:
+ ENABLED_LOGGING_SCOPES = LoggingScope.ALL
+ continue
+ elif spec.upper() == "-ALL":
+ ENABLED_LOGGING_SCOPES = LoggingScope.NONE
+ continue
+
+ # Parse scope name and operation
+ operation = spec[0]
+ scope_name = spec[1:].strip().upper()
+
+ try:
+ scope_enum = LoggingScope[scope_name]
+ if operation == '+':
+ ENABLED_LOGGING_SCOPES |= scope_enum
+ elif operation == '-':
+ ENABLED_LOGGING_SCOPES &= ~scope_enum
+ else:
+ logging.warning(f"Invalid operation '{operation}' in scope specification: {spec}")
+ except KeyError:
+ logging.warning(f"Unknown logging scope: {scope_name}")
+
+ elif isinstance(scopes, list):
+ # Convert list to comma-separated string and process
+ set_logging_scopes(",".join(scopes))
+
+
+def is_logging_scope_enabled(scope):
+ """Check if a specific logging scope is enabled."""
+ return bool(ENABLED_LOGGING_SCOPES & scope)
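+
+
+# Illustrative example of configuring scopes (e.g. from a command-line option):
+#   set_logging_scopes("+ALL,-FUNC_ENTRY_EXIT")
+#   is_logging_scope_enabled(LoggingScope.DOWNLOAD)         # -> True
+#   is_logging_scope_enabled(LoggingScope.FUNC_ENTRY_EXIT)  # -> False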
def send_slack_message(webhook, msg):
@@ -25,3 +123,153 @@ def sha256sum(path):
for byte_block in iter(lambda: f.read(8192), b''):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
+
+
+def log_function_entry_exit(logger=None):
+ """
+ Decorator that logs function entry and exit with timing information.
+ Only logs if the FUNC_ENTRY_EXIT scope is enabled.
+
+ Args:
+ logger: Optional logger instance. If not provided, uses the module's logger.
+ """
+ def decorator(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ global _call_stack_depth
+
+ if not is_logging_scope_enabled(LoggingScope.FUNC_ENTRY_EXIT):
+ return func(*args, **kwargs)
+
+ if logger is None:
+ log = logging.getLogger(func.__module__)
+ else:
+ log = logger
+
+ # Get context information if available
+ context = ""
+ if len(args) > 0 and hasattr(args[0], 'object'):
+ # For EessiTarball methods, show the tarball name and state
+ tarball = args[0]
+ filename = os.path.basename(tarball.object)
+
+ # Format filename to show important parts
+ if len(filename) > 30:
+ parts = filename.split('-')
+ if len(parts) >= 6: # Ensure we have all required parts
+ # Get version, component, last part of architecture, and epoch
+ version = parts[1]
+ component = parts[2]
+ arch_last = parts[-2].split('-')[-1] # Last part of architecture
+ epoch = parts[-1] # includes file extension
+ filename = f"{version}-{component}-{arch_last}-{epoch}"
+ else:
+ # Fallback to simple truncation if format doesn't match
+ filename = f"{filename[:15]}...{filename[-12:]}"
+
+ context = f" [{filename}"
+ if hasattr(tarball, 'state'):
+ context += f" in {tarball.state}"
+ context += "]"
+
+ # Create indentation based on call stack depth
+ indent = " " * _call_stack_depth
+
+ # Get file name and line number where the function is defined
+ file_name = os.path.basename(inspect.getsourcefile(func))
+ source_lines, start_line = inspect.getsourcelines(func)
+ # Find the line with the actual function definition
+ def_line = next(i for i, line in enumerate(source_lines) if line.strip().startswith('def '))
+ def_line_no = start_line + def_line
+ # Find the last non-empty line of the function
+ last_line = next(i for i, line in enumerate(reversed(source_lines)) if line.strip())
+ last_line_no = start_line + len(source_lines) - 1 - last_line
+
+ start_time = time.time()
+ log.info(f"{indent}[FUNC_ENTRY_EXIT] Entering {func.__name__} at {file_name}:{def_line_no}{context}")
+ _call_stack_depth += 1
+ try:
+ result = func(*args, **kwargs)
+ _call_stack_depth -= 1
+ end_time = time.time()
+ # For normal returns, show the last line of the function
+ log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{last_line_no}"
+ f"{context} (took {end_time - start_time:.2f}s)")
+ return result
+ except Exception as err:
+ _call_stack_depth -= 1
+ end_time = time.time()
+ # For exceptions, try to get the line number from the exception
+ try:
+ exc_line_no = err.__traceback__.tb_lineno
+ except AttributeError:
+ exc_line_no = last_line_no
+ log.info(f"{indent}[FUNC_ENTRY_EXIT] Leaving {func.__name__} at {file_name}:{exc_line_no}"
+ f"{context} with exception (took {end_time - start_time:.2f}s)")
+ raise err
+ return wrapper
+ return decorator
+
+
+def log_message(scope, level, msg, *args, logger=None, **kwargs):
+ """
+ Log a message if either:
+ 1. The specified scope is enabled, OR
+    2. The specified level is at or above the logger's effective log level
+
+ Args:
+ scope: LoggingScope value indicating which scope this logging belongs to
+ level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+ msg: Message to log
+ logger: Optional logger instance. If not provided, uses the root logger.
+ *args, **kwargs: Additional arguments to pass to the logging function
+ """
+ log = logger or logging.getLogger()
+ log_level = getattr(logging, level.upper())
+
+ # Check if either condition is met
+ if not (is_logging_scope_enabled(scope) or log_level >= log.getEffectiveLevel()):
+ return
+
+ # Create indentation based on call stack depth
+ indent = " " * _call_stack_depth
+ # Add scope to the message
+ scoped_msg = f"[{scope.name}] {msg}"
+ indented_msg = f"{indent}{scoped_msg}"
+
+ # If scope is enabled, use the temporary handler
+ if is_logging_scope_enabled(scope):
+ # Save original handlers
+ original_handlers = list(log.handlers)
+
+ # Create a temporary handler that accepts all levels
+ temp_handler = logging.StreamHandler(sys.stdout)
+ temp_handler.setLevel(logging.DEBUG)
+ temp_handler.setFormatter(logging.Formatter('%(levelname)-8s: %(message)s'))
+
+ try:
+ # Remove existing handlers temporarily
+ for handler in original_handlers:
+ log.removeHandler(handler)
+
+ # Add temporary handler
+ log.addHandler(temp_handler)
+
+ # Log the message
+ log_func = getattr(log, level.lower())
+ log_func(indented_msg, *args, **kwargs)
+ finally:
+ log.removeHandler(temp_handler)
+ # Restore original handlers
+ for handler in original_handlers:
+ if handler not in log.handlers:
+ log.addHandler(handler)
+ # Only use normal logging if scope is not enabled AND level is high enough
+ elif not is_logging_scope_enabled(scope) and log_level >= log.getEffectiveLevel():
+ # Use normal logging with level check
+ log_func = getattr(log, level.lower())
+ log_func(indented_msg, *args, **kwargs)
+
+# Example usage:
+# log_message(LoggingScope.DOWNLOAD, 'INFO', "Downloading file: %s", filename)
+# log_message(LoggingScope.ERROR, 'ERROR', "Failed to download: %s", error_msg)
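+#
+# Illustrative decorator usage; entry/exit lines are only emitted when the FUNC_ENTRY_EXIT
+# scope has been enabled via set_logging_scopes:
+#
+# @log_function_entry_exit()
+# def process_tarball(tarball):
+#     ...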
diff --git a/scripts/check-stratum-servers.py b/scripts/check-stratum-servers.py
index de4270d9..4e35b09e 100755
--- a/scripts/check-stratum-servers.py
+++ b/scripts/check-stratum-servers.py
@@ -9,7 +9,8 @@
import yaml
# Default location for EESSI's Ansible group vars file containing the CVMFS settings.
-DEFAULT_ANSIBLE_GROUP_VARS_LOCATION = 'https://raw.githubusercontent.com/EESSI/filesystem-layer/main/inventory/group_vars/all.yml'
+DEFAULT_ANSIBLE_GROUP_VARS_LOCATION = \
+ 'https://raw.githubusercontent.com/EESSI/filesystem-layer/main/inventory/group_vars/all.yml'
# Default fully qualified CVMFS repository name
DEFAULT_CVMFS_FQRN = 'software.eessi.io'
# Maximum amount of time (in minutes) that a Stratum 1 is allowed to not having performed a snapshot.
@@ -32,8 +33,8 @@ def find_stratum_urls(vars_file, fqrn):
"""Find all Stratum 0/1 URLs in a given Ansible YAML vars file that contains the EESSI CVMFS configuration."""
try:
group_vars = urllib.request.urlopen(vars_file)
- except:
- error(f'Cannot read the file that contains the Stratum 1 URLs from {vars_file}!')
+ except Exception as err:
+ error(f'Cannot read the file that contains the Stratum 1 URLs from {vars_file}!\nException: {err}')
try:
group_vars_yaml = yaml.safe_load(group_vars)
s1_urls = group_vars_yaml['eessi_cvmfs_server_urls'][0]['urls']
@@ -44,8 +45,8 @@ def find_stratum_urls(vars_file, fqrn):
break
else:
error(f'Could not find Stratum 0 URL in {vars_file}!')
- except:
- error(f'Cannot parse the yaml file from {vars_file}!')
+ except Exception as err:
+ error(f'Cannot parse the yaml file from {vars_file}!\nException: {err}')
return s0_url, s1_urls
@@ -64,7 +65,7 @@ def check_revisions(stratum_urls, fqrn):
revisions[stratum] = int(rev_matches[0])
else:
errors.append(f'Could not find revision number for stratum {stratum}!')
- except urllib.error.HTTPError as e:
+ except urllib.error.HTTPError:
errors.append(f'Could not connect to {stratum}!')
# Check if all revisions are the same.
@@ -95,10 +96,11 @@ def check_snapshots(s1_urls, fqrn, max_snapshot_delay=DEFAULT_MAX_SNAPSHOT_DELAY
# Stratum 1 servers are supposed to make a snapshot every few minutes,
# so let's check if it is not too far behind.
if now - last_snapshot_time > datetime.timedelta(minutes=max_snapshot_delay):
+ time_diff = (now - last_snapshot_time).total_seconds() / 60
errors.append(
- f'Stratum 1 {s1} has made its last snapshot {(now - last_snapshot_time).seconds / 60:.0f} minutes ago!')
- except urllib.error.HTTPError as e:
- errors.append(f'Could not connect to {s1_json}!')
+ f'Stratum 1 {s1} has made its last snapshot {time_diff:.0f} minutes ago!')
+ except urllib.error.HTTPError:
+ errors.append(f'Could not connect to {s1_snapshot_file}!')
if last_snapshots:
# Get the Stratum 1 with the most recent snapshot...