
Use the correct hive partition type information (hacky solution) #1

Open · wants to merge 1 commit into base: partition_escape
48 changes: 30 additions & 18 deletions dask_sql/input_utils/hive.py
@@ -18,7 +18,7 @@
sqlalchemy = None

from dask_sql.input_utils.base import BaseInputPlugin
from dask_sql.mappings import cast_column_type, sql_to_python_type
from dask_sql.mappings import cast_column_type, sql_to_python_type, sql_to_python_value

logger = logging.Logger(__name__)

@@ -131,24 +131,32 @@ def wrapped_read_function(location, column_information, **kwargs):
_,
) = parsed

partition_column_information = {
col: sql_to_python_type(col_type.upper())
for col, col_type in partition_column_information.items()
}

location = partition_table_information["Location"]
table = wrapped_read_function(
location, partition_column_information, **kwargs
)

# Now add the additional partition columns
partition_values = ast.literal_eval(
partition_table_information["Partition Value"]
)
partition_values = partition_table_information["Partition Value"]
partition_values = partition_values[1 : len(partition_values) - 1]
Review comment (Owner):

bug (blocking): Unless len(partition_values) is 1 (will it always?) then this will return a list object and the subsequent .split(',') call will fail.

partition_values = partition_values.split(",")

logger.debug(
f"Applying additional partition information as columns: {partition_information}"
)

partition_id = 0
for partition_key, partition_type in partition_information.items():
table[partition_key] = partition_values[partition_id]
table = cast_column_type(table, partition_key, partition_type)
value = partition_values[partition_id]
value = sql_to_python_value(partition_type.upper(), value)
table[partition_key] = value

# partition_type = sql_to_python_type()
# table = cast_column_type(table, partition_key, partition_type)

partition_id += 1
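For reference, a minimal sketch (not part of the diff) of what the new bracket-stripping and splitting does, assuming the raw "Partition Value" entry is a string such as "[2021, us]":

# Hypothetical raw value as it appears in the Hive table description
partition_values = "[2021, us]"
# Drop the surrounding brackets, then split on commas
partition_values = partition_values[1 : len(partition_values) - 1]
partition_values = partition_values.split(",")
# -> ["2021", " us"]; the entries stay strings and may keep a leading space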

@@ -166,17 +174,21 @@ def _escape_partition(self, partition: str): # pragma: no cover
Wrap anything but digits in quotes. Don't wrap the column name.
"""
contains_only_digits = re.compile(r"^\d+$")

try:
k, v = partition.split("=")
if re.match(contains_only_digits, v):
escaped_value = v
else:
escaped_value = f'"{v}"'
return f"{k}={escaped_value}"
except ValueError:
logger.warning(f"{partition} didn't contain a `=`")
return partition
escaped_partition = []

for partition_part in partition.split("/"):
try:
k, v = partition_part.split("=")
if re.match(contains_only_digits, v):
escaped_value = v
else:
escaped_value = f'"{v}"'
escaped_partition.append(f"{k}={escaped_value}")
except ValueError:
logger.warning(f"{partition_part} didn't contain a `=`")
escaped_partition.append(partition_part)

return ",".join(escaped_partition)

def _parse_hive_table_description(
self,
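As a quick illustration (an assumption, not part of the PR), the reworked _escape_partition now walks every "key=value" part of a "/"-separated partition spec, with hive_plugin standing in for an instance of the plugin class defined in hive.py:

# hive_plugin is a hypothetical instance of the Hive input plugin from hive.py
hive_plugin._escape_partition("year=2021/region=us")
# -> 'year=2021,region="us"'   (digit-only values stay bare, everything else is quoted)
hive_plugin._escape_partition("no_equals_sign")
# -> 'no_equals_sign'          (no "=": a warning is logged and the part passes through)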
9 changes: 9 additions & 0 deletions dask_sql/mappings.py
@@ -52,6 +52,7 @@
"FLOAT": np.float32,
"DECIMAL": np.float32,
"BIGINT": np.int64,
"INT": np.int32, # Although not in the standard, makes compatibility easier
"INTEGER": np.int32,
"SMALLINT": np.int16,
"TINYINT": np.int8,
@@ -157,6 +158,14 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
# NULL time
return pd.NaT # pragma: no cover

if isinstance(literal_value, str):
if sql_type == "DATE":
return datetime.strptime(literal_value, "%Y-%m-%d")
elif sql_type.startswith("TIME("):
return datetime.strptime(literal_value, "%H:%M:%S")  # assumption: TIME literals look like "12:30:00"
elif sql_type.startswith("TIMESTAMP("):
return datetime.strptime(literal_value, "%Y-%m-%d %H:%M:%S")  # strptime needs a format; assumption: timestamps look like "2021-01-31 12:30:00"

tz = literal_value.getTimeZone().getID()
assert str(tz) == "UTC", "The code can currently only handle UTC timezones"

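Finally, a hedged usage sketch (not part of the PR) of how the new string branches would convert partition literals, assuming the formats used in the branches above (the TIMESTAMP format is itself an assumption):

from dask_sql.mappings import sql_to_python_type, sql_to_python_value

# Partition values read from the Hive table description arrive as plain strings
sql_to_python_value("DATE", "2021-01-31")
# -> datetime(2021, 1, 31, 0, 0)
sql_to_python_value("TIMESTAMP(9)", "2021-01-31 12:30:00")
# -> datetime(2021, 1, 31, 12, 30), with the "%Y-%m-%d %H:%M:%S" format assumed above
sql_to_python_type("INT")
# -> numpy.int32 via the new "INT" entry (assuming the dict shown above backs this lookup)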