From 0909b80dbf3a22424dc6068827f8782aa2e0b464 Mon Sep 17 00:00:00 2001 From: michaelneely <141048027+michaelneely@users.noreply.github.com> Date: Tue, 17 Dec 2024 15:15:10 +0000 Subject: [PATCH] change: add _to_athena_query method to dataset_builder --- .../feature_store/dataset_builder.py | 106 +++++++++++------- 1 file changed, 63 insertions(+), 43 deletions(-) diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py index 289fa1ee0c..da7a0db5ca 100644 --- a/src/sagemaker/feature_store/dataset_builder.py +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]: os.remove(local_file_name) temp_table_name = f'dataframe_{temp_id.replace("-", "_")}' self._create_temp_table(temp_table_name, desired_s3_folder) - base_features = list(self._base.columns) - event_time_identifier_feature_dtype = self._base[ - self._event_time_identifier_feature_name - ].dtypes - self._event_time_identifier_feature_type = ( - FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( - str(event_time_identifier_feature_dtype), None - ) - ) - query_string = self._construct_query_string( - FeatureGroupToBeMerged( - base_features, - self._included_feature_names if self._included_feature_names else base_features, - self._included_feature_names if self._included_feature_names else base_features, - _DEFAULT_CATALOG, - _DEFAULT_DATABASE, - temp_table_name, - self._record_identifier_feature_name, - FeatureDefinition( - self._event_time_identifier_feature_name, - self._event_time_identifier_feature_type, - ), - None, - TableType.DATA_FRAME, - ) - ) - query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + query_result = self._run_query(*self._to_athena_query(temp_table_name=temp_table_name)) # TODO: cleanup temp table, need more clarification, keep it for now return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( "OutputLocation", None ), query_result.get("QueryExecution", {}).get("Query", None) if isinstance(self._base, FeatureGroup): - base_feature_group = construct_feature_group_to_be_merged( - self._base, self._included_feature_names - ) - self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name - self._event_time_identifier_feature_name = ( - base_feature_group.event_time_identifier_feature.feature_name - ) - self._event_time_identifier_feature_type = ( - base_feature_group.event_time_identifier_feature.feature_type - ) - query_string = self._construct_query_string(base_feature_group) - query_result = self._run_query( - query_string, - base_feature_group.catalog, - base_feature_group.database, - ) + query_result = self._run_query(*self._to_athena_query()) return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( "OutputLocation", None ), query_result.get("QueryExecution", {}).get("Query", None) @@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str: raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.") return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}" + def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]: + """Internal method for constructing an Athena query. + + Args: + temp_table_name (str): The temporary Athena table name of the base pandas.DataFrame. Defaults to None. + + Returns: + The query string. + The name of the catalog to be used in the query execution. + The database to be used in the query execution. + + Raises: + ValueError: temp_table_name must be provided if the base is a pandas.DataFrame. + """ + if isinstance(self._base, pd.DataFrame): + if temp_table_name is None: + raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.") + base_features = list(self._base.columns) + event_time_identifier_feature_dtype = self._base[ + self._event_time_identifier_feature_name + ].dtypes + self._event_time_identifier_feature_type = ( + FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + str(event_time_identifier_feature_dtype), None + ) + ) + catalog = _DEFAULT_CATALOG + database = _DEFAULT_DATABASE + query_string = self._construct_query_string( + FeatureGroupToBeMerged( + base_features, + self._included_feature_names if self._included_feature_names else base_features, + self._included_feature_names if self._included_feature_names else base_features, + catalog, + database, + temp_table_name, + self._record_identifier_feature_name, + FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + None, + TableType.DATA_FRAME, + ) + ) + if isinstance(self._base, FeatureGroup): + base_feature_group = construct_feature_group_to_be_merged( + self._base, self._included_feature_names + ) + self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name + self._event_time_identifier_feature_name = ( + base_feature_group.event_time_identifier_feature.feature_name + ) + self._event_time_identifier_feature_type = ( + base_feature_group.event_time_identifier_feature.feature_type + ) + catalog = base_feature_group.catalog + database = base_feature_group.database + query_string = self._construct_query_string(base_feature_group) + return query_string, catalog, database + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: """Internal method for execute Athena query, wait for query finish and get query result.