Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions accel.c
Original file line number Diff line number Diff line change
Expand Up @@ -4108,6 +4108,7 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
PyObject *py_colspec = NULL;
PyObject *py_str = NULL;
PyObject *py_blob = NULL;
PyObject **py_transformers = NULL;
Py_ssize_t length = 0;
uint64_t row_id = 0;
uint8_t is_null = 0;
Expand Down Expand Up @@ -4138,13 +4139,23 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)

colspec_l = PyObject_Length(py_colspec);
ctypes = malloc(sizeof(int) * colspec_l);
py_transformers = calloc(sizeof(PyObject*), colspec_l);

for (i = 0; i < colspec_l; i++) {
PyObject *py_cspec = PySequence_GetItem(py_colspec, i);
if (!py_cspec) goto error;
PyObject *py_ctype = PySequence_GetItem(py_cspec, 1);
if (!py_ctype) { Py_DECREF(py_cspec); goto error; }
ctypes[i] = (int)PyLong_AsLong(py_ctype);
py_transformers[i] = PySequence_GetItem(py_cspec, 2);
if (!py_transformers[i]) {
Py_DECREF(py_ctype);
Py_DECREF(py_cspec);
goto error;
}
if (py_transformers[i] == Py_None) {
py_transformers[i] = NULL;
}
Py_DECREF(py_ctype);
Py_DECREF(py_cspec);
if (PyErr_Occurred()) { goto error; }
Expand Down Expand Up @@ -4380,6 +4391,14 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
default:
goto error;
}

if (py_transformers[i]) {
PyObject *py_item = PyTuple_GetItem(py_row, i);
PyObject *py_transformed = PyObject_CallFunction(py_transformers[i], "O", py_item);
if (!py_transformed) goto error;
Py_DECREF(py_item);
CHECKRC(PyTuple_SetItem(py_row, i, py_transformed));
}
}

CHECKRC(PyList_Append(py_out_rows, py_row));
Expand All @@ -4389,6 +4408,12 @@ static PyObject *load_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)

exit:
if (ctypes) free(ctypes);
if (py_transformers) {
for (i = 0; i < colspec_l; i++) {
Py_XDECREF(py_transformers[i]);
}
free(py_transformers);
}

Py_XDECREF(py_row);

Expand All @@ -4412,6 +4437,7 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
PyObject *py_row_ids = NULL;
PyObject *py_row_ids_iter = NULL;
PyObject *py_item = NULL;
PyObject **py_transformers = NULL;
uint64_t row_id = 0;
uint8_t is_null = 0;
int8_t i8 = 0;
Expand Down Expand Up @@ -4459,12 +4485,26 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)

returns = malloc(sizeof(int) * n_cols);
if (!returns) goto error;
py_transformers = calloc(sizeof(PyObject*), n_cols);
if (!py_transformers) goto error;

for (i = 0; i < n_cols; i++) {
PyObject *py_item = PySequence_GetItem(py_returns, i);
if (!py_item) goto error;
returns[i] = (int)PyLong_AsLong(py_item);
Py_DECREF(py_item);
PyObject *py_cspec = PySequence_GetItem(py_returns, i);
if (!py_cspec) goto error;
PyObject *py_ctype = PySequence_GetItem(py_cspec, 1);
if (!py_ctype) { Py_DECREF(py_cspec); goto error; }
returns[i] = (int)PyLong_AsLong(py_ctype);
py_transformers[i] = PySequence_GetItem(py_cspec, 2);
if (!py_transformers[i]) {
Py_DECREF(py_ctype);
Py_DECREF(py_cspec);
goto error;
}
if (py_transformers[i] == Py_None) {
py_transformers[i] = NULL;
}
Py_DECREF(py_ctype);
Py_DECREF(py_cspec);
if (PyErr_Occurred()) { goto error; }
}

Expand Down Expand Up @@ -4504,6 +4544,13 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)
memcpy(out+out_idx, &is_null, 1);
out_idx += 1;

if (py_transformers[i]) {
PyObject *py_transformed = PyObject_CallFunction(py_transformers[i], "O", py_item);
if (!py_transformed) goto error;
Py_DECREF(py_item);
py_item = py_transformed;
}

switch (returns[i]) {
case MYSQL_TYPE_BIT:
// TODO
Expand Down Expand Up @@ -4702,6 +4749,12 @@ static PyObject *dump_rowdat_1(PyObject *self, PyObject *args, PyObject *kwargs)

exit:
if (returns) free(returns);
if (py_transformers) {
for (i = 0; i < n_cols; i++) {
Py_XDECREF(py_transformers[i]);
}
free(py_transformers);
}

Py_XDECREF(py_item);
Py_XDECREF(py_row_iter);
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ exclude =
docs/*
resources/*
licenses/*
max-complexity = 45
max-complexity = 50
max-line-length = 90
per-file-ignores =
singlestoredb/__init__.py:F401
Expand Down
65 changes: 30 additions & 35 deletions singlestoredb/functions/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
UDFType = Callable[..., Any]


def is_valid_type(obj: Any) -> bool:
def is_valid_object_type(obj: Any) -> bool:
"""Check if the object is a valid type for a schema definition."""
if not inspect.isclass(obj):
return False
Expand Down Expand Up @@ -52,48 +52,34 @@ def is_valid_callable(obj: Any) -> bool:

returns = utils.get_annotations(obj).get('return', None)

if inspect.isclass(returns) and issubclass(returns, str):
if inspect.isclass(returns) and issubclass(returns, SQLString):
return True

raise TypeError(
f'callable {obj} must return a str, '
f'but got {returns}',
)
return False


def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
def expand_types(args: Any) -> Any:
"""Expand the types for the function arguments / return values."""
if args is None:
return None

# SQL string
if isinstance(args, str):
return [args]

# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
elif is_valid_type(args):
return args

# List of SQL strings or callables
elif isinstance(args, list):
new_args = []
for arg in args:
if isinstance(arg, str):
new_args.append(arg)
elif callable(arg):
new_args.append(arg())
else:
raise TypeError(f'unrecognized type for parameter: {arg}')
return new_args
return []

# Callable that returns a SQL string
elif is_valid_callable(args):
out = args()
if not isinstance(out, str):
raise TypeError(f'unrecognized type for parameter: {args}')
return [out]
is_list = True
if not isinstance(args, list):
is_list = False
args = [args]

raise TypeError(f'unrecognized type for parameter: {args}')
new_args = []
for arg in args:
if isinstance(arg, str):
new_args.append(arg)
elif is_valid_callable(arg):
new_args.append(arg())
else:
new_args.append(arg)

if not is_list:
return new_args[0]
return new_args


def _func(
Expand All @@ -106,6 +92,15 @@ def _func(
) -> UDFType:
"""Generic wrapper for UDF and TVF decorators."""

if isinstance(args, dict):
raise TypeError(
'The `args` parameter must be a list of data types, not a dict.',
)
if isinstance(returns, dict):
raise TypeError(
'The `returns` parameter must be a list of data types, not a dict.',
)

_singlestoredb_attrs = { # type: ignore
k: v for k, v in dict(
name=name,
Expand Down
15 changes: 12 additions & 3 deletions singlestoredb/functions/ext/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def to_thread(
'float64': ft.DOUBLE,
'str': ft.STRING,
'bytes': -ft.STRING,
'json': ft.STRING,
}


Expand Down Expand Up @@ -586,7 +587,11 @@ def make_func(
dtype = x['dtype'].replace('?', '')
if dtype not in rowdat_1_type_map:
raise TypeError(f'no data type mapping for {dtype}')
colspec.append((x['name'], rowdat_1_type_map[dtype]))
colspec.append((
x['name'],
rowdat_1_type_map[dtype],
x.get('transformer', None),
))
info['colspec'] = colspec

# Setup return type
Expand All @@ -595,7 +600,11 @@ def make_func(
dtype = x['dtype'].replace('?', '')
if dtype not in rowdat_1_type_map:
raise TypeError(f'no data type mapping for {dtype}')
returns.append((x['name'], rowdat_1_type_map[dtype]))
returns.append((
x['name'],
rowdat_1_type_map[dtype],
x.get('transformer', None),
))
info['returns'] = returns

return do_func, info
Expand Down Expand Up @@ -1084,7 +1093,7 @@ async def __call__(

with timer('format_output'):
body = output_handler['dump'](
[x[1] for x in func_info['returns']], *result, # type: ignore
func_info['returns'], *result, # type: ignore
)

await send(output_handler['response'])
Expand Down
Loading
Loading