Skip to content

Commit c2577f1

Browse files
committedJul 15, 2024
Fix cooperation of our load(s) implementation with a user supplied object_hook
1 parent 7a5579f commit c2577f1

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed
 

‎json_numpy.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
import json
77
from base64 import b64decode, b64encode
88
from functools import partial
9-
from typing import Any, Callable
9+
from typing import TYPE_CHECKING, Any, Callable
1010

1111
from numpy import frombuffer, generic, ndarray
1212
from numpy.lib.format import descr_to_dtype, dtype_to_descr
1313

14+
if TYPE_CHECKING: # pragma: no cover
15+
from _typeshed import SupportsRead
16+
1417

1518
def default(
1619
o: Any, *, fallback_default: Callable[[Any], dict[str, Any]] | None = None
@@ -88,19 +91,25 @@ def dumps(*args: Any, cls: type[json.JSONEncoder] | None = None, **kwargs: Any)
8891
return _dumps(*args, cls=_patch_encoder, **kwargs) # type: ignore[arg-type]
8992

9093

91-
def loads(*args: Any, **kwargs: Any) -> Any:
92-
kwargs.setdefault("object_hook", object_hook)
93-
return _loads(*args, **kwargs)
94+
def loads(
95+
*args: Any, object_hook: Callable[[dict], Any] | None = None, **kwargs: Any
96+
) -> Any:
97+
return _loads(
98+
*args,
99+
object_hook=_hook
100+
if object_hook is None
101+
else lambda dct: _hook(object_hook(dct)),
102+
**kwargs,
103+
)
94104

95105

96106
def dump(*args: Any, cls: type[json.JSONEncoder] | None = None, **kwargs: Any) -> None:
97107
kwargs["user_cls"] = cls
98108
return _dump(*args, cls=_patch_encoder, **kwargs) # type: ignore[arg-type]
99109

100110

101-
def load(*args: Any, **kwargs: Any) -> Any:
102-
kwargs.setdefault("object_hook", object_hook)
103-
return _load(*args, **kwargs)
111+
def load(fp: SupportsRead[str | bytes], **kwargs: Any) -> Any:
112+
return loads(fp.read(), **kwargs)
104113

105114

106115
def patch() -> None:

‎tests/test_json_numpy.py

+12
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,15 @@ def default(self, _: Any) -> dict[str, Any]:
221221
x = np.random.rand(5).astype(np.float32)
222222
dumped = json.dumps(x, cls=Encoder)
223223
self.assert_equal_with_type(json.loads(dumped), x)
224+
225+
def test_loads_object_hook(self) -> None:
226+
def hook(dct: dict) -> dict | int:
227+
if "foo" in dct:
228+
return dct["foo"]
229+
return dct
230+
231+
foo = {"foo": "bar"}
232+
x = np.random.rand(5).astype(np.float32)
233+
result = json.loads(json.dumps([foo, x]), object_hook=hook)
234+
self.assertEqual(result[0], "bar")
235+
self.assert_equal_with_type(result[1], x)

0 commit comments

Comments
 (0)
Please sign in to comment.