|
1 | 1 | """Contains functions for data readers."""
|
2 | 2 | import json
|
| 3 | +import logging |
| 4 | +import os |
3 | 5 | import re
|
4 | 6 | import urllib
|
5 | 7 | from collections import OrderedDict
|
|
19 | 21 | cast,
|
20 | 22 | )
|
21 | 23 |
|
| 24 | +import boto3 |
| 25 | +import botocore |
22 | 26 | import dateutil
|
23 | 27 | import pandas as pd
|
24 | 28 | import pyarrow.parquet as pq
|
@@ -843,3 +847,125 @@ def url_to_bytes(url_as_string: Url, options: Dict) -> BytesIO:
|
843 | 847 |
|
844 | 848 | stream.seek(0)
|
845 | 849 | return stream
|
| 850 | + |
| 851 | + |
class S3Helper:
    """
    A utility class for working with Amazon S3.

    This class provides static helpers to check if a path is an S3 URI,
    to create an S3 client, and to download an S3 object into memory.
    """

    # URI prefixes that are treated as pointing at S3.
    _S3_URI_PREFIXES = ("s3://", "s3a://")

    @staticmethod
    def is_s3_uri(path: str, logger: logging.Logger) -> bool:
        """
        Check if the given path is an S3 URI.

        This function checks for common S3 URI prefixes "s3://" and "s3a://".
        Leading and trailing whitespace is ignored; the comparison is
        case-sensitive.

        Args:
            path (str): The path to check for an S3 URI.
            logger (logging.Logger): The logger instance for logging.

        Returns:
            bool: True if the path is an S3 URI, False otherwise.
        """
        path = path.strip()
        # str.startswith accepts a tuple, so one call covers every prefix.
        is_s3 = path.startswith(S3Helper._S3_URI_PREFIXES)
        if not is_s3:
            logger.debug(f"'{path}' is not a valid S3 URI")

        return is_s3

    @staticmethod
    def _create_boto3_client(
        aws_access_key_id: Optional[str],
        aws_secret_access_key: Optional[str],
        aws_session_token: Optional[str],
        region_name: Optional[str],
    ) -> "botocore.client.BaseClient":
        """Build a boto3 S3 client from the given credentials and region."""
        return boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=region_name,
        )

    @staticmethod
    def create_s3_client(
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        region_name: Optional[str] = None,
    ) -> "botocore.client.BaseClient":
        """
        Create and return an S3 client.

        Any credential argument left as None is read from the matching
        environment variable (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
        AWS_SESSION_TOKEN). Values that are still None are passed to boto3
        unchanged.

        Args:
            aws_access_key_id (Optional[str]): The AWS access key ID.
            aws_secret_access_key (Optional[str]): The AWS secret access key.
            aws_session_token (Optional[str]): The AWS session token
                (optional, typically used for temporary credentials).
            region_name (Optional[str]): The AWS region name. When None, the
                first non-empty value of the AWS_REGION and
                AWS_DEFAULT_REGION environment variables is used, falling
                back to 'us-east-1'.

        Returns:
            botocore.client.BaseClient: A S3 client instance.

        Raises:
            ValueError: If boto3 reports missing credentials while the
                client is built and no access key ID / secret access key
                could be determined.
        """
        # Use environment variables as fallback for credentials that
        # were not provided explicitly.
        if aws_access_key_id is None:
            aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
        if aws_secret_access_key is None:
            aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
        if aws_session_token is None:
            aws_session_token = os.environ.get("AWS_SESSION_TOKEN")

        # Same fallback for the region; AWS_DEFAULT_REGION is consulted as
        # well so a configured default is not overridden by 'us-east-1'.
        if region_name is None:
            region_name = (
                os.environ.get("AWS_REGION")
                or os.environ.get("AWS_DEFAULT_REGION")
                or "us-east-1"
            )

        try:
            return S3Helper._create_boto3_client(
                aws_access_key_id, aws_secret_access_key, aws_session_token, region_name
            )
        except botocore.exceptions.NoCredentialsError as err:
            # Retrying with the same arguments cannot succeed, so report
            # the problem instead of calling boto3 a second time.
            # NOTE(review): boto3 appears to raise this error when a request
            # is signed rather than at client creation -- confirm whether
            # this handler is reachable in practice.
            if aws_access_key_id is None or aws_secret_access_key is None:
                raise ValueError(
                    "AWS access key ID and secret access key are required."
                ) from err
            raise

    @staticmethod
    def get_s3_uri(s3_uri: str, s3_client: "botocore.client.BaseClient") -> BytesIO:
        """
        Download an object from an S3 URI and return its content as BytesIO.

        The URI is split as ``<scheme>://<bucket>/<key>``. Everything after
        the bucket is kept as the object key, including any '?' or '#'
        characters. Leading slashes of the key are dropped.

        Args:
            s3_uri (str): The S3 URI specifying the location of the object
                to download. Surrounding whitespace is ignored.
            s3_client (botocore.client.BaseClient): An initialized AWS S3
                client for accessing the S3 service.

        Returns:
            BytesIO: A BytesIO object containing the content of
                the downloaded S3 object.
        """
        # Split by hand rather than with urlsplit: a generic URL parser
        # treats '?' and '#' as query/fragment separators and would cut
        # them (and everything after them) out of the object key.
        _scheme, _sep, location = s3_uri.strip().partition("://")
        bucket_name, _sep, raw_key = location.partition("/")
        file_key = raw_key.lstrip("/")

        # Download the S3 object
        response = s3_client.get_object(Bucket=bucket_name, Key=file_key)

        # Read the whole payload, then release the underlying stream even
        # if the read fails.
        body = response["Body"]
        try:
            content = body.read()
        finally:
            body.close()

        return BytesIO(content)
0 commit comments