Fix Null, UUID and df arrow table output format #308

Merged: 8 commits, Mar 10, 2025
2 changes: 1 addition & 1 deletion chdb/dbapi/connections.py

@@ -57,7 +57,7 @@ def cursor(self, cursor=None):
             return Cursor(self)
         return Cursor(self)
 
-    def query(self, sql, fmt="ArrowStream"):
+    def query(self, sql, fmt="CSV"):
         """Execute a query and return the raw result."""
         if self._closed:
             raise err.InterfaceError("Connection closed")
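Note: with this change, raw DB-API queries default to CSV instead of an Arrow stream. A minimal sketch of the caller-visible effect (the connection setup is illustrative and assumes chdb's dbapi connect() helper, which is not part of this diff):

    import chdb.dbapi as dbapi

    conn = dbapi.connect()
    raw = conn.query("SELECT 1 AS x")                        # now a CSV-formatted result by default
    arrow = conn.query("SELECT 1 AS x", fmt="ArrowStream")   # old behavior, now opt-in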
48 changes: 30 additions & 18 deletions chdb/dbapi/cursors.py

@@ -5,10 +5,11 @@
 
 # executemany only supports simple bulk insert.
 # You can use it to load large dataset.
 RE_INSERT_VALUES = re.compile(
-    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
-    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
-    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
-    re.IGNORECASE | re.DOTALL)
+    r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)"
+    + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))"
+    + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
+    re.IGNORECASE | re.DOTALL,
+)
 
 
 class Cursor(object):

@@ -131,13 +132,17 @@ def execute(self, query, args=None):
 
         self._cursor.execute(query)
 
-        # Get description from Arrow schema
-        if self._cursor._current_table is not None:
+        # Get description from column names and types
+        if hasattr(self._cursor, "_column_names") and self._cursor._column_names:
             self.description = [
-                (field.name, field.type.to_pandas_dtype(), None, None, None, None, None)
-                for field in self._cursor._current_table.schema
+                (name, type_info, None, None, None, None, None)
+                for name, type_info in zip(
+                    self._cursor._column_names, self._cursor._column_types
+                )
             ]
-            self.rowcount = self._cursor._current_table.num_rows
+            self.rowcount = (
+                len(self._cursor._current_table) if self._cursor._current_table else -1
+            )
         else:
             self.description = None
             self.rowcount = -1

@@ -164,16 +169,23 @@ def executemany(self, query, args):
         if m:
             q_prefix = m.group(1) % ()
             q_values = m.group(2).rstrip()
-            q_postfix = m.group(3) or ''
-            assert q_values[0] == '(' and q_values[-1] == ')'
-            return self._do_execute_many(q_prefix, q_values, q_postfix, args,
-                                         self.max_stmt_length,
-                                         self._get_db().encoding)
+            q_postfix = m.group(3) or ""
+            assert q_values[0] == "(" and q_values[-1] == ")"
+            return self._do_execute_many(
+                q_prefix,
+                q_values,
+                q_postfix,
+                args,
+                self.max_stmt_length,
+                self._get_db().encoding,
+            )
 
         self.rowcount = sum(self.execute(query, arg) for arg in args)
         return self.rowcount
 
-    def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
+    def _do_execute_many(
+        self, prefix, values, postfix, args, max_stmt_length, encoding
+    ):
         conn = self._get_db()
         escape = self._escape_args
         if isinstance(prefix, str):

@@ -184,18 +196,18 @@ def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encod
         args = iter(args)
         v = values % escape(next(args), conn)
         if isinstance(v, str):
-            v = v.encode(encoding, 'surrogateescape')
+            v = v.encode(encoding, "surrogateescape")
         sql += v
         rows = 0
         for arg in args:
             v = values % escape(arg, conn)
             if isinstance(v, str):
-                v = v.encode(encoding, 'surrogateescape')
+                v = v.encode(encoding, "surrogateescape")
             if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                 rows += self.execute(sql + postfix)
                 sql = prefix
             else:
-                sql += ','.encode(encoding)
+                sql += ",".encode(encoding)
             sql += v
         rows += self.execute(sql + postfix)
         self.rowcount = rows
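With the Arrow schema gone, cursor.description now carries the ClickHouse type names reported by the engine rather than pandas dtypes. A hedged usage sketch (query and expected tuples are illustrative, assuming a dbapi connection named conn):

    cur = conn.cursor()
    cur.execute("SELECT toUInt32(1) AS id, NULL AS note")
    print(cur.description)
    # e.g. [('id', 'UInt32', None, None, None, None, None),
    #       ('note', 'Nullable(Nothing)', None, None, None, None, None)]
    # the type_code slot is now a ClickHouse type string, not a pandas dtype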
167 changes: 155 additions & 12 deletions chdb/state/sqlitelike.py

@@ -1,4 +1,3 @@
-import io
 from typing import Optional, Any
 from chdb import _chdb
 

@@ -11,6 +10,36 @@
     raise ImportError("Failed to import pyarrow") from None
 
 
+_arrow_format = set({"dataframe", "arrowtable"})
+_process_result_format_funs = {
+    "dataframe": lambda x: to_df(x),
+    "arrowtable": lambda x: to_arrowTable(x),
+}
+
+
+# return pyarrow table
+def to_arrowTable(res):
+    """convert res to arrow table"""
+    # try import pyarrow and pandas, if failed, raise ImportError with suggestion
+    try:
+        import pyarrow as pa  # noqa
+        import pandas as pd  # noqa
+    except ImportError as e:
+        print(f"ImportError: {e}")
+        print('Please install pyarrow and pandas via "pip install pyarrow pandas"')
+        raise ImportError("Failed to import pyarrow or pandas") from None
+    if len(res) == 0:
+        return pa.Table.from_batches([], schema=pa.schema([]))
+    return pa.RecordBatchFileReader(res.bytes()).read_all()
+
+
+# return pandas dataframe
+def to_df(r):
+    """convert arrow table to Dataframe"""
+    t = to_arrowTable(r)
+    return t.to_pandas(use_threads=True)
+
+
 class Connection:
     def __init__(self, connection_string: str):
         # print("Connection", connection_string)

@@ -22,7 +51,13 @@ def cursor(self) -> "Cursor":
         return self._cursor
 
     def query(self, query: str, format: str = "CSV") -> Any:
-        return self._conn.query(query, format)
+        lower_output_format = format.lower()
+        result_func = _process_result_format_funs.get(lower_output_format, lambda x: x)
+        if lower_output_format in _arrow_format:
+            format = "Arrow"
+
+        result = self._conn.query(query, format)
+        return result_func(result)
 
     def close(self) -> None:
         # print("close")
@@ -41,17 +76,103 @@ def __init__(self, connection):
     def execute(self, query: str) -> None:
         self._cursor.execute(query)
         result_mv = self._cursor.get_memview()
-        # print("get_result", result_mv)
         if self._cursor.has_error():
             raise Exception(self._cursor.error_message())
         if self._cursor.data_size() == 0:
             self._current_table = None
             self._current_row = 0
+            self._column_names = []
+            self._column_types = []
             return
-        arrow_data = result_mv.tobytes()
-        reader = pa.ipc.open_stream(io.BytesIO(arrow_data))
-        self._current_table = reader.read_all()
-        self._current_row = 0
+
+        # Parse JSON data
+        json_data = result_mv.tobytes().decode("utf-8")
+        import json
+
+        try:
+            # First line contains column names
+            # Second line contains column types
+            # Following lines contain data
+            lines = json_data.strip().split("\n")
+            if len(lines) < 2:
+                self._current_table = None
+                self._current_row = 0
+                self._column_names = []
+                self._column_types = []
+                return
+
+            self._column_names = json.loads(lines[0])
+            self._column_types = json.loads(lines[1])
+
+            # Convert data rows
+            rows = []
+            for line in lines[2:]:
+                if not line.strip():
+                    continue
+                row_data = json.loads(line)
+                converted_row = []
+                for val, type_info in zip(row_data, self._column_types):
+                    # Handle NULL values first
+                    if val is None:
+                        converted_row.append(None)
+                        continue
+
+                    # Basic type conversion
+                    try:
+                        if type_info.startswith("Int") or type_info.startswith("UInt"):
+                            converted_row.append(int(val))
+                        elif type_info.startswith("Float"):
+                            converted_row.append(float(val))
+                        elif type_info == "Bool":
+                            converted_row.append(bool(val))
+                        elif type_info == "String" or type_info == "FixedString":
+                            converted_row.append(str(val))
+                        elif type_info.startswith("DateTime"):
+                            from datetime import datetime
+
+                            # Check if the value is numeric (timestamp)
+                            val_str = str(val)
+                            if val_str.replace(".", "").isdigit():
+                                converted_row.append(datetime.fromtimestamp(float(val)))
+                            else:
+                                # Handle datetime string formats
+                                if "." in val_str:  # Has microseconds
+                                    converted_row.append(
+                                        datetime.strptime(
+                                            val_str, "%Y-%m-%d %H:%M:%S.%f"
+                                        )
+                                    )
+                                else:  # No microseconds
+                                    converted_row.append(
+                                        datetime.strptime(val_str, "%Y-%m-%d %H:%M:%S")
+                                    )
+                        elif type_info.startswith("Date"):
+                            from datetime import date, datetime
+
+                            # Check if the value is numeric (days since epoch)
+                            val_str = str(val)
+                            if val_str.isdigit():
+                                converted_row.append(
+                                    date.fromtimestamp(float(val) * 86400)
+                                )
+                            else:
+                                # Handle date string format
+                                converted_row.append(
+                                    datetime.strptime(val_str, "%Y-%m-%d").date()
+                                )
+                        else:
+                            # For unsupported types, keep as string
+                            converted_row.append(str(val))
+                    except (ValueError, TypeError):
+                        # If conversion fails, keep original value as string
+                        converted_row.append(str(val))
+                rows.append(tuple(converted_row))
+
+            self._current_table = rows
+            self._current_row = 0
+
+        except json.JSONDecodeError as e:
+            raise Exception(f"Failed to parse JSON data: {e}")
 
     def commit(self) -> None:
         self._cursor.commit()
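For reference, JSONCompactEachRowWithNamesAndTypes emits one JSON array per line: column names first, then ClickHouse type names, then one array per data row, which is exactly the layout the parser above expects. A hypothetical payload for an id/name/created result set might look like:

    ["id", "name", "created"]
    ["Nullable(UUID)", "String", "DateTime"]
    ["f4bf890f-f9e8-4966-a186-72808312aa44", "alice", "2025-03-10 12:34:56"]
    [null, "bob", "2025-03-10 12:35:00"]

NULLs arrive as JSON null, which is why the "val is None" check runs before any type-specific conversion; UUIDs and other unhandled types fall through to the string branch instead of failing.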
@@ -60,12 +181,10 @@ def fetchone(self) -> Optional[tuple]:
         if not self._current_table or self._current_row >= len(self._current_table):
             return None
 
-        row_dict = {
-            col: self._current_table.column(col)[self._current_row].as_py()
-            for col in self._current_table.column_names
-        }
+        # Now self._current_table is a list of row tuples
+        row = self._current_table[self._current_row]
         self._current_row += 1
-        return tuple(row_dict.values())
+        return row
 
     def fetchmany(self, size: int = 1) -> tuple:
         if not self._current_table:

@@ -99,6 +218,30 @@ def __next__(self) -> tuple:
             raise StopIteration
         return row
 
+    def column_names(self) -> list:
+        """Return a list of column names from the last executed query"""
+        return self._column_names if hasattr(self, "_column_names") else []
+
+    def column_types(self) -> list:
+        """Return a list of column types from the last executed query"""
+        return self._column_types if hasattr(self, "_column_types") else []
+
+    @property
+    def description(self) -> list:
+        """
+        Return a description of the columns as per DB-API 2.0
+        Returns a list of 7-item tuples, each containing:
+        (name, type_code, display_size, internal_size, precision, scale, null_ok)
+        where only name and type_code are provided
+        """
+        if not hasattr(self, "_column_names") or not self._column_names:
+            return []
+
+        return [
+            (name, type_info, None, None, None, None, None)
+            for name, type_info in zip(self._column_names, self._column_types)
+        ]
+
 
 def connect(connection_string: str = ":memory:") -> Connection:
     """
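Taken together, the cursor now behaves like a plain row-tuple cursor instead of a view over an Arrow table. A short sketch of the intended surface (query and values are illustrative, using the connect() helper from this file):

    conn = connect(":memory:")
    cur = conn.cursor()
    cur.execute("SELECT toUInt32(42) AS answer, NULL AS note")
    print(cur.column_names())   # ['answer', 'note']
    print(cur.column_types())   # e.g. ['UInt32', 'Nullable(Nothing)']
    print(cur.fetchone())       # (42, None): a plain Python tuple, NULL preserved as None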
4 changes: 2 additions & 2 deletions programs/local/LocalChdb.cpp

@@ -321,9 +321,9 @@ void cursor_wrapper::execute(const std::string & query_str)
     release_result();
     global_query_obj = findQueryableObjFromQuery(query_str);
 
-    // Always use Arrow format internally
+    // Use JSONCompactEachRowWithNamesAndTypes format for better type support
     py::gil_scoped_release release;
-    current_result = query_conn(conn->get_conn(), query_str.c_str(), "ArrowStream");
+    current_result = query_conn(conn->get_conn(), query_str.c_str(), "JSONCompactEachRowWithNamesAndTypes");
 }
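Since the cursor path now requests JSONCompactEachRowWithNamesAndTypes from the engine, the same payload can be inspected directly through the top-level API. A quick sanity check (output shape follows the format; exact type names depend on the query):

    import chdb

    print(chdb.query("SELECT 1 AS x, NULL AS y", "JSONCompactEachRowWithNamesAndTypes"))
    # ["x", "y"]
    # ["UInt8", "Nullable(Nothing)"]
    # [1, null]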