Skip to content

Commit

Permalink
Dataset builder: add _to_athena_query method
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelneely committed Dec 17, 2024
1 parent 342fbbc commit 0f5da99
Showing 1 changed file with 63 additions and 43 deletions.
106 changes: 63 additions & 43 deletions src/sagemaker/feature_store/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]:
os.remove(local_file_name)
temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
self._create_temp_table(temp_table_name, desired_s3_folder)
base_features = list(self._base.columns)
event_time_identifier_feature_dtype = self._base[
self._event_time_identifier_feature_name
].dtypes
self._event_time_identifier_feature_type = (
FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
str(event_time_identifier_feature_dtype), None
)
)
query_string = self._construct_query_string(
FeatureGroupToBeMerged(
base_features,
self._included_feature_names if self._included_feature_names else base_features,
self._included_feature_names if self._included_feature_names else base_features,
_DEFAULT_CATALOG,
_DEFAULT_DATABASE,
temp_table_name,
self._record_identifier_feature_name,
FeatureDefinition(
self._event_time_identifier_feature_name,
self._event_time_identifier_feature_type,
),
None,
TableType.DATA_FRAME,
)
)
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
query_result = self._run_query(*self._to_athena_query(temp_table_name=temp_table_name))
# TODO: cleanup temp table, need more clarification, keep it for now
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
"OutputLocation", None
), query_result.get("QueryExecution", {}).get("Query", None)
if isinstance(self._base, FeatureGroup):
base_feature_group = construct_feature_group_to_be_merged(
self._base, self._included_feature_names
)
self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
self._event_time_identifier_feature_name = (
base_feature_group.event_time_identifier_feature.feature_name
)
self._event_time_identifier_feature_type = (
base_feature_group.event_time_identifier_feature.feature_type
)
query_string = self._construct_query_string(base_feature_group)
query_result = self._run_query(
query_string,
base_feature_group.catalog,
base_feature_group.database,
)
query_result = self._run_query(*self._to_athena_query())
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
"OutputLocation", None
), query_result.get("QueryExecution", {}).get("Query", None)
Expand Down Expand Up @@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"

def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
"""Internal method for constructing an Athena query.
Args:
temp_table_name (str): The temporary Athena table name of the base pandas.DataFrame. Defaults to None.
Returns:
The query string.
The name of the catalog to be used in the query execution.
The database to be used in the query execution.
Raises:
ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
"""
if isinstance(self._base, pd.DataFrame):
if temp_table_name is None:
raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
base_features = list(self._base.columns)
event_time_identifier_feature_dtype = self._base[
self._event_time_identifier_feature_name
].dtypes
self._event_time_identifier_feature_type = (
FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
str(event_time_identifier_feature_dtype), None
)
)
catalog = _DEFAULT_CATALOG
database = _DEFAULT_DATABASE
query_string = self._construct_query_string(
FeatureGroupToBeMerged(
base_features,
self._included_feature_names if self._included_feature_names else base_features,
self._included_feature_names if self._included_feature_names else base_features,
catalog,
database,
temp_table_name,
self._record_identifier_feature_name,
FeatureDefinition(
self._event_time_identifier_feature_name,
self._event_time_identifier_feature_type,
),
None,
TableType.DATA_FRAME,
)
)
if isinstance(self._base, FeatureGroup):
base_feature_group = construct_feature_group_to_be_merged(
self._base, self._included_feature_names
)
self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
self._event_time_identifier_feature_name = (
base_feature_group.event_time_identifier_feature.feature_name
)
self._event_time_identifier_feature_type = (
base_feature_group.event_time_identifier_feature.feature_type
)
catalog = base_feature_group.catalog
database = base_feature_group.database
query_string = self._construct_query_string(base_feature_group)
return query_string, catalog, database

def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
"""Internal method for execute Athena query, wait for query finish and get query result.
Expand Down

0 comments on commit 0f5da99

Please sign in to comment.