diff --git a/aws_glue_etl_docker/glueshim.py b/aws_glue_etl_docker/glueshim.py
index 3337cd6..81cd546 100644
--- a/aws_glue_etl_docker/glueshim.py
+++ b/aws_glue_etl_docker/glueshim.py
@@ -3,13 +3,33 @@
 import shutil
 import glob
 import pyspark
+import logging
 from pyspark import SparkConf, SparkContext, SQLContext
 from pprint import pprint
 
 
 def _load_data(filePaths, dataset_name, spark_context, groupfiles, groupsize):
     sqlContext = SQLContext(spark_context)
+
+    return (sqlContext
+        .read
+        .option("multiLine", "true")
+        .option("inferSchema", "true")
+        .json(filePaths))
+
     return sqlContext.read.json(filePaths)
 
+def _load_data_from_catalog(database, table_name, fileType, spark_context):
+    if fileType == "csv":
+        sqlContext = SQLContext(spark_context)
+        return (sqlContext
+            .read
+            .option("header", "true")
+            .option("mode", "DROPMALFORMED")
+            .csv(database))
+    else:
+        return _load_data(database, table_name, spark_context, '', '')
+
 def _write_csv(dataframe, bucket, location, dataset_name, spark_context):
     output_path = '/data/' + bucket + '/' + location
@@ -31,7 +51,7 @@ def _get_spark_context():
     return (pyspark.SparkContext.getOrCreate(), None)
 
 def _get_all_files_with_prefix(bucket, prefix, spark_context):
-    pathToWalk = '/data/' + bucket + '/' + prefix +'**/*.*' 
+    pathToWalk = '/data/' + bucket + '/' + prefix +'**/*.*'
     return glob.glob(pathToWalk,recursive=True)
 
 def _is_in_aws():
@@ -40,6 +60,9 @@
 def _get_arguments(default):
     return default
 
+def _get_logger(spark_context):
+    return logging.getLogger("broadcast")
+
 def _finish(self):
     return None
@@ -50,7 +73,7 @@
     from awsglue.dynamicframe import DynamicFrame
     from awsglue.job import Job
     import boto3
-    
+
     def _load_data(file_paths, dataset_name, context, groupfiles, groupsize):
         connection_options = {'paths': file_paths}
 
@@ -67,7 +90,12 @@ def _load_data(file_paths, dataset_name, context, groupfiles, groupsize):
             transformation_ctx=dataset_name)
 
         return glue0.toDF()
-    
+
+    def _load_data_from_catalog(database, table_name, type, context):
+        dynamic_frame = context.create_dynamic_frame.from_catalog(
+            database=database, table_name=table_name)
+        return dynamic_frame.toDF()
+
     def _write_csv(dataframe, bucket, location, dataset_name, spark_context):
         output_path = "s3://" + bucket + "/" + location
         df_tmp = DynamicFrame.fromDF(dataframe.repartition(1), spark_context, dataset_name)
@@ -88,11 +116,11 @@ def _delete_files_with_prefix(bucket, prefix):
                 for obj in page['Contents']:
                     if not obj['Key'].endswith('/'):
                         delete_keys['Objects'].append({'Key': str(obj['Key'])})
-            
+
             s3.delete_objects(Bucket=bucket, Delete=delete_keys)
             delete_keys = {'Objects' : []}
-    
+
     def _write_parquet(dataframe, bucket, location, partition_columns, dataset_name, spark_context):
         if "job-bookmark-disable" in sys.argv:
             _delete_files_with_prefix(bucket, location)
 
@@ -115,7 +143,7 @@ def _get_spark_context():
         job = Job(spark_context)
         args = _get_arguments({})
         job.init(args['JOB_NAME'], args)
-        
+
         return (spark_context, job)
 
     def _get_all_files_with_prefix(bucket, prefix, spark_context):
@@ -129,11 +157,14 @@ def _get_all_files_with_prefix(bucket, prefix, spark_context):
                 if not obj['Key'].endswith('/') and '/' in obj['Key']:
                     idx = obj['Key'].rfind('/')
                     prefixes.add('s3://{}/{}'.format(bucket, obj['Key'][0:idx]))
-        
+
         return list(prefixes)
-    
+
     def _get_arguments(defaults):
-        return getResolvedOptions(sys.argv, ['JOB_NAME'] + defaults.keys()) 
+        return getResolvedOptions(sys.argv, ['JOB_NAME'] + defaults.keys())
+
+    def _get_logger(spark_context):
+        return spark_context.get_logger()
 
     def _is_in_aws():
         return True
@@ -145,26 +176,31 @@ def _finish(self):
         except NameError:
             print("unable to commit job")
-    
+
 except Exception as e:
     print('local dev')
-    
-class GlueShim: 
+
+class GlueShim:
     def __init__(self):
         c = _get_spark_context()
         self.spark_context = c[0]
         self.job = c[1]
         self._groupfiles = None
         self._groupsize = None
-    
+
     def arguments(self, defaults):
         """Gets the arguments for a job.  When running in glue, the response is pulled form sys.argv
 
         Keyword arguments:
         defaults -- default dictionary of options
         """
-        return _get_arguments(defaults) 
-    
+        return _get_arguments(defaults)
+
+    def get_logger(self ):
+        """Gets the default logger for the job
+        """
+        return _get_logger(self.spark_context)
+
     def load_data(self, file_paths, dataset_name):
         """Loads data into a dataframe
 
@@ -173,7 +209,16 @@ def load_data(self, file_paths, dataset_name):
         dataset_name -- name of this dataset, used for glue bookmarking
         """
         return _load_data(file_paths, dataset_name, self.spark_context, self._groupfiles, self._groupsize)
-    
+
+    def load_data_from_catalog(self, database, table_name, type = "csv"):
+        """Loads data into a dataframe from the glue catalog
+
+        Keyword arguments:
+        database -- the glue database to read from
+        table_name -- the table name to read
+        """
+        return _load_data_from_catalog(database, table_name, type, self.spark_context)
+
     def get_all_files_with_prefix(self, bucket, prefix):
         """Given a bucket and file prefix, this method will return a list of all files with that prefix
 
@@ -195,7 +240,7 @@ def write_parquet(self, dataframe, bucket, location, partition_columns, dataset_
         """
         _write_parquet(dataframe, bucket, location, partition_columns, dataset_name, self.spark_context)
-    
+
     def write_csv(self, dataframe, bucket, location, dataset_name):
         """Writes a dataframe in csv format with a partition count of 1
 
@@ -207,10 +252,10 @@ def write_csv(self, dataframe, bucket, location, dataset_name):
         """
         _write_csv(dataframe, bucket, location, dataset_name, self.spark_context)
-    
+
     def get_spark_context(self):
         """ Gets the spark context """
-        return self.context
+        return self.spark_context
 
     def finish(self):
         """ Should be run at the end, will set Glue bookmarks """
@@ -219,7 +264,7 @@ def finish(self):
     def set_group_files(self, groupfiles):
         """ Sets extra options used with glue https://docs.aws.amazon.com/glue/latest/dg/grouping-input-files.html """
         self._groupfiles = groupfiles
-    
+
     def set_group_size(self, groupsize):
         """ Sets extra options used with glue https://docs.aws.amazon.com/glue/latest/dg/grouping-input-files.html """
         self._groupsize = groupsize
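
For reviewers, a minimal sketch of how a job script might call the new catalog and logger methods introduced by this patch. It is not part of the change itself: the database, table, bucket, and dataset names are placeholders, and the import path is assumed from the file header above (aws_glue_etl_docker/glueshim.py).

# Hypothetical usage sketch; all names below are placeholders.
from aws_glue_etl_docker.glueshim import GlueShim

shim = GlueShim()
logger = shim.get_logger()  # Glue's logger on AWS, a stdlib logging.Logger in the local Docker shim

# On AWS this reads the table from the Glue Data Catalog; the local shim instead
# treats the first argument as a path and reads it as CSV (or JSON when type != "csv").
frame = shim.load_data_from_catalog("example_database", "example_table", type="csv")
logger.info("loaded {} rows".format(frame.count()))

shim.write_csv(frame, "example-bucket", "output/example", "example_dataset")
shim.finish()  # commits the Glue job bookmark when running in AWS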