diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py
index ad55d0e2c..8204810e0 100644
--- a/dask_sql/input_utils/hive.py
+++ b/dask_sql/input_utils/hive.py
@@ -18,7 +18,7 @@
     sqlalchemy = None
 
 from dask_sql.input_utils.base import BaseInputPlugin
-from dask_sql.mappings import cast_column_type, sql_to_python_type
+from dask_sql.mappings import cast_column_type, sql_to_python_type, sql_to_python_value
 
 logger = logging.Logger(__name__)
 
@@ -131,15 +131,19 @@ def wrapped_read_function(location, column_information, **kwargs):
                     _,
                 ) = parsed
 
+                partition_column_information = {
+                    col: sql_to_python_type(col_type.upper())
+                    for col, col_type in partition_column_information.items()
+                }
+
                 location = partition_table_information["Location"]
                 table = wrapped_read_function(
                     location, partition_column_information, **kwargs
                 )
 
-                # Now add the additional partition columns
-                partition_values = ast.literal_eval(
-                    partition_table_information["Partition Value"]
-                )
+                partition_values = partition_table_information["Partition Value"]
+                partition_values = partition_values[1 : len(partition_values) - 1]
+                partition_values = partition_values.split(",")
 
                 logger.debug(
                     f"Applying additional partition information as columns: {partition_information}"
@@ -147,8 +151,12 @@
 
                 partition_id = 0
                 for partition_key, partition_type in partition_information.items():
-                    table[partition_key] = partition_values[partition_id]
-                    table = cast_column_type(table, partition_key, partition_type)
+                    value = partition_values[partition_id]
+                    value = sql_to_python_value(partition_type.upper(), value)
+                    table[partition_key] = value
+
+                    # The value is already converted to a Python value here,
+                    # so the old cast_column_type() call is not needed anymore.
 
                     partition_id += 1
 
@@ -166,17 +174,21 @@ def _escape_partition(self, partition: str):  # pragma: no cover
         Wrap anything but digits in quotes. Don't wrap the column name.
""" contains_only_digits = re.compile(r"^\d+$") - - try: - k, v = partition.split("=") - if re.match(contains_only_digits, v): - escaped_value = v - else: - escaped_value = f'"{v}"' - return f"{k}={escaped_value}" - except ValueError: - logger.warning(f"{partition} didn't contain a `=`") - return partition + escaped_partition = [] + + for partition_part in partition.split("/"): + try: + k, v = partition_part.split("=") + if re.match(contains_only_digits, v): + escaped_value = v + else: + escaped_value = f'"{v}"' + escaped_partition.append(f"{k}={escaped_value}") + except ValueError: + logger.warning(f"{partition_part} didn't contain a `=`") + escaped_partition.append(partition_part) + + return ",".join(escaped_partition) def _parse_hive_table_description( self, diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 5df6a42bd..c9b044d64 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -52,6 +52,7 @@ "FLOAT": np.float32, "DECIMAL": np.float32, "BIGINT": np.int64, + "INT": np.int32, # Although not in the standard, makes compatibility easier "INTEGER": np.int32, "SMALLINT": np.int16, "TINYINT": np.int8, @@ -157,6 +158,14 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: # NULL time return pd.NaT # pragma: no cover + if isinstance(literal_value, str): + if sql_type == "DATE": + return datetime.strptime(literal_value, "%Y-%m-%d") + elif sql_type.startswith("TIME("): + return datetime.strptime(literal_value, "%H:%M:%S %Y") + elif sql_type.startswith("TIMESTAMP("): + return datetime.strptime(literal_value) + tz = literal_value.getTimeZone().getID() assert str(tz) == "UTC", "The code can currently only handle UTC timezones"