Skip to content

Commit 5de1be2

Browse files
committed
Add float_hook to json decoder
This adds support for a ``float_hook`` to the JSON decoder. If set, this hook is called to decode any untyped JSON float value from its raw string representation. It may be used to change the default float parsing so that untyped floats are once again returned as ``decimal.Decimal`` values rather than ``float`` values. Since this is an uncommon option, it is only available on the ``Decoder``, rather than on the top-level ``msgspec.json.decode`` function.
1 parent ef7808b commit 5de1be2

File tree

4 files changed

+121
-5
lines changed

4 files changed

+121
-5
lines changed

msgspec/_core.c

+57-5
Original file line numberDiff line numberDiff line change
@@ -11092,6 +11092,22 @@ parse_number_nonfinite(
1109211092
return ms_post_decode_float(val, type, path, strict, true);
1109311093
}
1109411094

11095+
/*
 * Decode a JSON float literal by handing its raw text to a user callback.
 *
 * buf/size   - the bytes of the number literal (not NUL-terminated; the
 *              caller in parse_number_inline passes `start` and `p - start`).
 * path       - location of the value in the message, used only when
 *              reporting an error.
 * float_hook - callable invoked with a single `str` argument.
 *
 * Returns whatever object float_hook returns (a new reference), or NULL with
 * an exception set if allocating the string or calling the hook fails.
 */
static MS_NOINLINE PyObject *
11096+
json_float_hook(
11097+
const char *buf, Py_ssize_t size, PathNode *path, PyObject *float_hook
11098+
) {
11099+
/* maxchar=127 requests an ASCII-only str with room for `size` characters */
PyObject *str = PyUnicode_New(size, 127);
11100+
if (str == NULL) return NULL;
11101+
/* NOTE(review): ascii_get_buffer presumably returns the str's writable
 * character storage -- confirm against its definition elsewhere in the file */
memcpy(ascii_get_buffer(str), buf, size);
11102+
PyObject *out = CALL_ONE_ARG(float_hook, str);
11103+
/* The temporary str is no longer needed whether or not the call succeeded */
Py_DECREF(str);
11104+
if (out == NULL) {
11105+
/* Hook raised: let the helper (defined elsewhere) decorate the pending
 * error with `path`; the tests in this commit expect a ValidationError
 * whose message includes the field location */
ms_maybe_wrap_validation_error(path);
11106+
return NULL;
11107+
}
11108+
return out;
11109+
}
11110+
1109511111
static MS_INLINE PyObject *
1109611112
parse_number_inline(
1109711113
const unsigned char *p,
@@ -11101,6 +11117,7 @@ parse_number_inline(
1110111117
TypeNode *type,
1110211118
PathNode *path,
1110311119
bool strict,
11120+
PyObject *float_hook,
1110411121
bool from_str
1110511122
) {
1110611123
uint64_t mantissa = 0;
@@ -11286,6 +11303,9 @@ parse_number_inline(
1128611303
(char *)start, p - start, true, path, NULL
1128711304
);
1128811305
}
11306+
else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) {
11307+
return json_float_hook((char *)start, p - start, path, float_hook);
11308+
}
1128911309
else {
1129011310
if (MS_UNLIKELY(exponent > 288 || exponent < -307)) {
1129111311
/* Exponent is out of bounds */
@@ -11363,6 +11383,7 @@ maybe_parse_number(
1136311383
type,
1136411384
path,
1136511385
strict,
11386+
NULL,
1136611387
true
1136711388
);
1136811389
return (*out != NULL || errmsg == NULL);
@@ -15403,6 +15424,7 @@ typedef struct JSONDecoderState {
1540315424
/* Configuration */
1540415425
TypeNode *type;
1540515426
PyObject *dec_hook;
15427+
PyObject *float_hook;
1540615428
bool strict;
1540715429

1540815430
/* Temporary scratch space */
@@ -15425,10 +15447,11 @@ typedef struct JSONDecoder {
1542515447
TypeNode *type;
1542615448
char strict;
1542715449
PyObject *dec_hook;
15450+
PyObject *float_hook;
1542815451
} JSONDecoder;
1542915452

1543015453
PyDoc_STRVAR(JSONDecoder__doc__,
15431-
"Decoder(type='Any', *, strict=True, dec_hook=None)\n"
15454+
"Decoder(type='Any', *, strict=True, dec_hook=None, float_hook=None)\n"
1543215455
"--\n"
1543315456
"\n"
1543415457
"A JSON decoder.\n"
@@ -15449,19 +15472,28 @@ PyDoc_STRVAR(JSONDecoder__doc__,
1544915472
" signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type`` is the\n"
1545015473
" expected message type, and ``obj`` is the decoded representation composed\n"
1545115474
" of only basic JSON types. This hook should transform ``obj`` into type\n"
15452-
" ``type``, or raise a ``NotImplementedError`` if unsupported."
15475+
" ``type``, or raise a ``NotImplementedError`` if unsupported.\n"
15476+
"float_hook : callable, optional\n"
15477+
" An optional callback for handling decoding untyped float literals. Should\n"
15478+
" have the signature ``float_hook(val: str) -> Any``, where ``val`` is the\n"
15479+
" raw string value of the JSON float. This hook is called to decode any\n"
15480+
" \"untyped\" float value (e.g. ``typing.Any`` typed). The default is\n"
15481+
" equivalent to ``float_hook=float``, where all untyped JSON floats are\n"
15482+
" decoded as python floats. Specifying ``float_hook=decimal.Decimal``\n"
15483+
" will decode all untyped JSON floats as decimals instead."
1545315484
);
1545415485
static int
1545515486
JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
1545615487
{
15457-
char *kwlist[] = {"type", "strict", "dec_hook", NULL};
15488+
char *kwlist[] = {"type", "strict", "dec_hook", "float_hook", NULL};
1545815489
MsgspecState *st = msgspec_get_global_state();
1545915490
PyObject *type = st->typing_any;
1546015491
PyObject *dec_hook = NULL;
15492+
PyObject *float_hook = NULL;
1546115493
int strict = 1;
1546215494

1546315495
if (!PyArg_ParseTupleAndKeywords(
15464-
args, kwds, "|O$pO", kwlist, &type, &strict, &dec_hook)
15496+
args, kwds, "|O$pOO", kwlist, &type, &strict, &dec_hook, &float_hook)
1546515497
) {
1546615498
return -1;
1546715499
}
@@ -15479,6 +15511,19 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds)
1547915511
}
1548015512
self->dec_hook = dec_hook;
1548115513

15514+
/* Handle float_hook */
15515+
if (float_hook == Py_None) {
15516+
float_hook = NULL;
15517+
}
15518+
if (float_hook != NULL) {
15519+
if (!PyCallable_Check(float_hook)) {
15520+
PyErr_SetString(PyExc_TypeError, "float_hook must be callable");
15521+
return -1;
15522+
}
15523+
Py_INCREF(float_hook);
15524+
}
15525+
self->float_hook = float_hook;
15526+
1548215527
/* Handle strict */
1548315528
self->strict = strict;
1548415529

@@ -15498,6 +15543,7 @@ JSONDecoder_traverse(JSONDecoder *self, visitproc visit, void *arg)
1549815543
if (out != 0) return out;
1549915544
Py_VISIT(self->orig_type);
1550015545
Py_VISIT(self->dec_hook);
15546+
Py_VISIT(self->float_hook);
1550115547
return 0;
1550215548
}
1550315549

@@ -15508,6 +15554,7 @@ JSONDecoder_dealloc(JSONDecoder *self)
1550815554
TypeNode_Free(self->type);
1550915555
Py_XDECREF(self->orig_type);
1551015556
Py_XDECREF(self->dec_hook);
15557+
Py_XDECREF(self->float_hook);
1551115558
Py_TYPE(self)->tp_free((PyObject *)self);
1551215559
}
1551315560

@@ -17551,7 +17598,7 @@ json_maybe_decode_number(JSONDecoderState *self, TypeNode *type, PathNode *path)
1755117598
PyObject *out = parse_number_inline(
1755217599
self->input_pos, self->input_end,
1755317600
&pout, &errmsg,
17554-
type, path, self->strict, false
17601+
type, path, self->strict, self->float_hook, false
1755517602
);
1755617603
self->input_pos = (unsigned char *)pout;
1755717604

@@ -18014,6 +18061,7 @@ msgspec_json_format(PyObject *self, PyObject *args, PyObject *kwargs)
1801418061

1801518062
/* Init decoder */
1801618063
dec.dec_hook = NULL;
18064+
dec.float_hook = NULL;
1801718065
dec.type = NULL;
1801818066
dec.scratch = NULL;
1801918067
dec.scratch_capacity = 0;
@@ -18095,6 +18143,7 @@ JSONDecoder_decode(JSONDecoder *self, PyObject *const *args, Py_ssize_t nargs)
1809518143
.type = self->type,
1809618144
.strict = self->strict,
1809718145
.dec_hook = self->dec_hook,
18146+
.float_hook = self->float_hook,
1809818147
.scratch = NULL,
1809918148
.scratch_capacity = 0,
1810018149
.scratch_len = 0
@@ -18161,6 +18210,7 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na
1816118210
.type = self->type,
1816218211
.strict = self->strict,
1816318212
.dec_hook = self->dec_hook,
18213+
.float_hook = self->float_hook,
1816418214
.scratch = NULL,
1816518215
.scratch_capacity = 0,
1816618216
.scratch_len = 0
@@ -18237,6 +18287,7 @@ static PyMemberDef JSONDecoder_members[] = {
1823718287
{"type", T_OBJECT_EX, offsetof(JSONDecoder, orig_type), READONLY, "The Decoder type"},
1823818288
{"strict", T_BOOL, offsetof(JSONDecoder, strict), READONLY, "The Decoder strict setting"},
1823918289
{"dec_hook", T_OBJECT, offsetof(JSONDecoder, dec_hook), READONLY, "The Decoder dec_hook"},
18290+
{"float_hook", T_OBJECT, offsetof(JSONDecoder, float_hook), READONLY, "The Decoder float_hook"},
1824018291
{NULL},
1824118292
};
1824218293

@@ -18334,6 +18385,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO
1833418385
JSONDecoderState state = {
1833518386
.strict = strict,
1833618387
.dec_hook = dec_hook,
18388+
.float_hook = NULL,
1833718389
.scratch = NULL,
1833818390
.scratch_capacity = 0,
1833918391
.scratch_len = 0

msgspec/json.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ T = TypeVar("T")
1818

1919
enc_hook_sig = Optional[Callable[[Any], Any]]
2020
dec_hook_sig = Optional[Callable[[type, Any], Any]]
21+
float_hook_sig = Optional[Callable[[str], Any]]
2122

2223
class Encoder:
2324
enc_hook: enc_hook_sig
@@ -41,13 +42,15 @@ class Decoder(Generic[T]):
4142
type: Type[T]
4243
strict: bool
4344
dec_hook: dec_hook_sig
45+
float_hook: float_hook_sig
4446

4547
@overload
4648
def __init__(
4749
self: Decoder[Any],
4850
*,
4951
strict: bool = True,
5052
dec_hook: dec_hook_sig = None,
53+
float_hook: float_hook_sig = None,
5154
) -> None: ...
5255
@overload
5356
def __init__(
@@ -56,6 +59,7 @@ class Decoder(Generic[T]):
5659
*,
5760
strict: bool = True,
5861
dec_hook: dec_hook_sig = None,
62+
float_hook: float_hook_sig = None,
5963
) -> None: ...
6064
@overload
6165
def __init__(
@@ -64,6 +68,7 @@ class Decoder(Generic[T]):
6468
*,
6569
strict: bool = True,
6670
dec_hook: dec_hook_sig = None,
71+
float_hook: float_hook_sig = None,
6772
) -> None: ...
6873
def decode(self, data: Union[bytes, str]) -> T: ...
6974
def decode_lines(self, data: Union[bytes, str]) -> list[T]: ...

tests/basic_typing_examples.py

+9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import datetime
5+
import decimal
56
import pickle
67
from typing import Any, Dict, Final, List, Type, Union
78

@@ -826,6 +827,14 @@ def dec_hook(typ: Type, obj: Any) -> Any:
826827
msgspec.json.Decoder(dec_hook=dec_hook)
827828

828829

830+
# Typing example: `float_hook` accepts None or any callable taking a `str`
# (e.g. `float`, `decimal.Decimal`), matching `float_hook_sig` in json.pyi.
def check_json_Decoder_float_hook() -> None:
831+
msgspec.json.Decoder(float_hook=None)
832+
msgspec.json.Decoder(float_hook=float)
833+
dec = msgspec.json.Decoder(float_hook=decimal.Decimal)
834+
# The attribute is typed Optional, so it must be narrowed before being called.
if dec.float_hook is not None:
835+
dec.float_hook("1.5")
836+
837+
829838
def check_json_Decoder_strict() -> None:
830839
dec = msgspec.json.Decoder(List[int], strict=False)
831840
reveal_type(dec.strict) # assert "bool" in typ

tests/test_json.py

+50
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import datetime
5+
import decimal
56
import enum
67
import gc
78
import itertools
@@ -524,6 +525,19 @@ def test_decode_lines_bad_call(self):
524525
with pytest.raises(TypeError):
525526
dec.decode(1)
526527

528+
# Constructor handling of `float_hook`: it defaults to None, an explicit None
# is stored as None, a callable is stored as-is, and a non-callable is
# rejected with TypeError.
def test_decoder_init_float_hook(self):
529+
dec = msgspec.json.Decoder()
530+
assert dec.float_hook is None
531+
532+
dec = msgspec.json.Decoder(float_hook=None)
533+
assert dec.float_hook is None
534+
535+
dec = msgspec.json.Decoder(float_hook=decimal.Decimal)
536+
assert dec.float_hook is decimal.Decimal
537+
538+
# An int is not callable, so construction must fail.
with pytest.raises(TypeError):
539+
dec = msgspec.json.Decoder(float_hook=1)
540+
527541

528542
class TestBoolAndNone:
529543
def test_encode_none(self):
@@ -1567,6 +1581,42 @@ def test_decode_float_err_expected_int(self, s):
15671581
):
15681582
msgspec.json.decode(s, type=int)
15691583

1584+
# With no explicit type (the decoder defaults to Any), a JSON float is routed
# through `float_hook`, so the result is exactly what the hook returns.
def test_float_hook_untyped(self):
1585+
dec = msgspec.json.Decoder(float_hook=decimal.Decimal)
1586+
res = dec.decode(b"1.33")
1587+
assert res == decimal.Decimal("1.33")
1588+
# Check the concrete type too: the value must not have gone through `float` first.
assert type(res) is decimal.Decimal
1589+
1590+
# The hook only applies to float literals in `Any`-typed positions: fields
# typed `float` or `decimal.Decimal` keep their normal decoding, and an
# integer literal in an `Any` field is not passed to the hook.
def test_float_hook_typed(self):
1591+
class Ex(msgspec.Struct):
1592+
a: float
1593+
b: decimal.Decimal
1594+
c: Any
1595+
d: Any
1596+
1597+
# A distinctive wrapper type makes it visible which values went through the hook.
class MyFloat(NamedTuple):
1598+
x: str
1599+
1600+
dec = msgspec.json.Decoder(Ex, float_hook=MyFloat)
1601+
res = dec.decode(b'{"a": 1.5, "b": 1.3, "c": 1.3, "d": 123}')
1602+
# Only "c" (a float literal in an Any field) is expected to be wrapped, receiving the raw text "1.3".
sol = Ex(1.5, decimal.Decimal("1.3"), MyFloat("1.3"), 123)
1603+
assert res == sol
1604+
1605+
# An exception raised inside `float_hook` must surface as a
# msgspec.ValidationError that keeps the original message and reports the
# path of the offending field.
def test_float_hook_error(self):
1606+
def float_hook(val):
1607+
raise ValueError("Oh no!")
1608+
1609+
class Ex(msgspec.Struct):
1610+
a: float
1611+
b: Any
1612+
1613+
dec = msgspec.json.Decoder(Ex, float_hook=float_hook)
1614+
# The hook is never invoked here: "a" is typed float and "b" holds an integer literal.
assert dec.decode(b'{"a": 1.5, "b": 2}') == Ex(a=1.5, b=2)
1615+
with pytest.raises(msgspec.ValidationError) as rec:
1616+
dec.decode(b'{"a": 1.5, "b": 2.5}')
1617+
assert "Oh no!" in str(rec.value)
1618+
assert "at `$.b`" in str(rec.value)
1619+
15701620

15711621
class TestDecimal:
15721622
"""Most decimal tests are in test_common.py, the ones here are for json

0 commit comments

Comments
 (0)