Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Store Dataset builder: add _to_athena_query method #4969

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 63 additions & 43 deletions src/sagemaker/feature_store/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]:
os.remove(local_file_name)
temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
self._create_temp_table(temp_table_name, desired_s3_folder)
base_features = list(self._base.columns)
event_time_identifier_feature_dtype = self._base[
self._event_time_identifier_feature_name
].dtypes
self._event_time_identifier_feature_type = (
FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
str(event_time_identifier_feature_dtype), None
)
)
query_string = self._construct_query_string(
FeatureGroupToBeMerged(
base_features,
self._included_feature_names if self._included_feature_names else base_features,
self._included_feature_names if self._included_feature_names else base_features,
_DEFAULT_CATALOG,
_DEFAULT_DATABASE,
temp_table_name,
self._record_identifier_feature_name,
FeatureDefinition(
self._event_time_identifier_feature_name,
self._event_time_identifier_feature_type,
),
None,
TableType.DATA_FRAME,
)
)
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
query_result = self._run_query(*self._to_athena_query(temp_table_name=temp_table_name))
# TODO: cleanup temp table, need more clarification, keep it for now
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
"OutputLocation", None
), query_result.get("QueryExecution", {}).get("Query", None)
if isinstance(self._base, FeatureGroup):
base_feature_group = construct_feature_group_to_be_merged(
self._base, self._included_feature_names
)
self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
self._event_time_identifier_feature_name = (
base_feature_group.event_time_identifier_feature.feature_name
)
self._event_time_identifier_feature_type = (
base_feature_group.event_time_identifier_feature.feature_type
)
query_string = self._construct_query_string(base_feature_group)
query_result = self._run_query(
query_string,
base_feature_group.catalog,
base_feature_group.database,
)
query_result = self._run_query(*self._to_athena_query())
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
"OutputLocation", None
), query_result.get("QueryExecution", {}).get("Query", None)
Expand Down Expand Up @@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"

def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
    """Internal method for constructing an Athena query.

    Builds the query string for the configured base and reports which catalog and
    database the query should be executed against. As a side effect, the event time
    identifier feature type (and, for a FeatureGroup base, the record identifier and
    event time identifier feature names) stored on this builder are refreshed from
    the base.

    Args:
        temp_table_name (str): The temporary Athena table name of the base
            pandas.DataFrame. Only used when the base is a pandas.DataFrame.
            Defaults to None.

    Returns:
        The query string.
        The name of the catalog to be used in the query execution.
        The database to be used in the query execution.

    Raises:
        ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
        ValueError: the base is neither a pandas.DataFrame nor a FeatureGroup.
    """
    if isinstance(self._base, pd.DataFrame):
        # Validate before touching any state so a bad call leaves the builder unchanged.
        if temp_table_name is None:
            raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
        base_features = list(self._base.columns)
        event_time_identifier_feature_dtype = self._base[
            self._event_time_identifier_feature_name
        ].dtypes
        # Map the pandas dtype of the event time column to a feature type;
        # dtypes missing from the map resolve to None.
        self._event_time_identifier_feature_type = (
            FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
                str(event_time_identifier_feature_dtype), None
            )
        )
        # When no subset of features was requested, every base column is selected.
        selected_features = (
            self._included_feature_names if self._included_feature_names else base_features
        )
        query_string = self._construct_query_string(
            FeatureGroupToBeMerged(
                base_features,
                selected_features,
                selected_features,
                _DEFAULT_CATALOG,
                _DEFAULT_DATABASE,
                temp_table_name,
                self._record_identifier_feature_name,
                FeatureDefinition(
                    self._event_time_identifier_feature_name,
                    self._event_time_identifier_feature_type,
                ),
                None,
                TableType.DATA_FRAME,
            )
        )
        return query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE

    if isinstance(self._base, FeatureGroup):
        base_feature_group = construct_feature_group_to_be_merged(
            self._base, self._included_feature_names
        )
        # Keep the builder's identifier metadata in sync with the base feature group.
        self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
        self._event_time_identifier_feature_name = (
            base_feature_group.event_time_identifier_feature.feature_name
        )
        self._event_time_identifier_feature_type = (
            base_feature_group.event_time_identifier_feature.feature_type
        )
        return (
            self._construct_query_string(base_feature_group),
            base_feature_group.catalog,
            base_feature_group.database,
        )

    # Previously an unsupported base fell through to the return statement and failed
    # with an UnboundLocalError; fail with an explicit, descriptive error instead.
    raise ValueError(f"Base {self._base} of type {type(self._base)} is not supported.")

def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
"""Internal method for execute Athena query, wait for query finish and get query result.

Expand Down
Loading