Skip to content

Commit f02e982

Browse files
authored
INTPYTHON-389 - Avoid Type Mismatch Conversion to NaN (#316)
1 parent bf5af52 commit f02e982

File tree

7 files changed

+500
-105
lines changed

7 files changed

+500
-105
lines changed

bindings/python/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33

44
---
55

6+
# Changes in Version 1.9.0 (2025/XX/XX)
7+
8+
- Providing a schema now enforces strict type adherence for data.
9+
If a result contains a field whose value does not match the schema's type for that field, a ``TypeError`` will be raised.
10+
Note that ``NaN`` is a valid value for fields of any type.
11+
To suppress these errors and instead silently convert such mismatches to ``NaN``, pass the ``allow_invalid=True`` argument to your ``pymongoarrow`` API call.
12+
For example, a result with a field of type ``int`` but with a string value will now raise a ``TypeError``,
13+
unless ``allow_invalid=True`` is passed, in which case the result's field will have a value of ``NaN``.
14+
615
# Changes in Version 1.8.0 (2025/05/12)
716

817
- Add support for PyArrow 20.0.

bindings/python/pymongoarrow/api.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
_MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE)
7373

7474

75-
def find_arrow_all(collection, query, *, schema=None, **kwargs):
75+
def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
7676
"""Method that returns the results of a find query as a
7777
:class:`pyarrow.Table` instance.
7878
@@ -83,14 +83,18 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
8383
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
8484
If the schema is not given, it will be inferred using the data in the
8585
result set.
86+
- `allow_invalid` (optional): If set to ``True``,
87+
results will have all fields that do not conform to the schema silently converted to NaN.
8688
8789
Additional keyword-arguments passed to this method will be passed
8890
directly to the underlying ``find`` operation.
8991
9092
:Returns:
9193
An instance of :class:`pyarrow.Table`.
9294
"""
93-
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
95+
context = PyMongoArrowContext(
96+
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
97+
)
9498

9599
for opt in ("cursor_type",):
96100
if kwargs.pop(opt, None):
@@ -110,7 +114,7 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs):
110114
return context.finish()
111115

112116

113-
def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
117+
def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
114118
"""Method that returns the results of an aggregation pipeline as a
115119
:class:`pyarrow.Table` instance.
116120
@@ -121,14 +125,18 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs):
121125
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
122126
If the schema is not given, it will be inferred using the data in the
123127
result set.
128+
- `allow_invalid` (optional): If set to ``True``,
129+
results will have all fields that do not conform to the schema silently converted to NaN.
124130
125131
Additional keyword-arguments passed to this method will be passed
126132
directly to the underlying ``aggregate`` operation.
127133
128134
:Returns:
129135
An instance of :class:`pyarrow.Table`.
130136
"""
131-
context = PyMongoArrowContext(schema, codec_options=collection.codec_options)
137+
context = PyMongoArrowContext(
138+
schema, codec_options=collection.codec_options, allow_invalid=allow_invalid
139+
)
132140

133141
if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
134142
msg = (
@@ -165,7 +173,7 @@ def _arrow_to_pandas(arrow_table):
165173
return arrow_table.to_pandas(split_blocks=True, self_destruct=True)
166174

167175

168-
def find_pandas_all(collection, query, *, schema=None, **kwargs):
176+
def find_pandas_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
169177
"""Method that returns the results of a find query as a
170178
:class:`pandas.DataFrame` instance.
171179
@@ -176,17 +184,21 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs):
176184
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
177185
If the schema is not given, it will be inferred using the data in the
178186
result set.
187+
- `allow_invalid` (optional): If set to ``True``,
188+
results will have all fields that do not conform to the schema silently converted to NaN.
179189
180190
Additional keyword-arguments passed to this method will be passed
181191
directly to the underlying ``find`` operation.
182192
183193
:Returns:
184194
An instance of :class:`pandas.DataFrame`.
185195
"""
186-
return _arrow_to_pandas(find_arrow_all(collection, query, schema=schema, **kwargs))
196+
return _arrow_to_pandas(
197+
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
198+
)
187199

188200

189-
def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
201+
def aggregate_pandas_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
190202
"""Method that returns the results of an aggregation pipeline as a
191203
:class:`pandas.DataFrame` instance.
192204
@@ -197,14 +209,20 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs):
197209
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
198210
If the schema is not given, it will be inferred using the data in the
199211
result set.
212+
- `allow_invalid` (optional): If set to ``True``,
213+
results will have all fields that do not conform to the schema silently converted to NaN.
200214
201215
Additional keyword-arguments passed to this method will be passed
202216
directly to the underlying ``aggregate`` operation.
203217
204218
:Returns:
205219
An instance of :class:`pandas.DataFrame`.
206220
"""
207-
return _arrow_to_pandas(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
221+
return _arrow_to_pandas(
222+
aggregate_arrow_all(
223+
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
224+
)
225+
)
208226

209227

210228
def _arrow_to_numpy(arrow_table, schema=None):
@@ -227,7 +245,7 @@ def _arrow_to_numpy(arrow_table, schema=None):
227245
return container
228246

229247

230-
def find_numpy_all(collection, query, *, schema=None, **kwargs):
248+
def find_numpy_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
231249
"""Method that returns the results of a find query as a
232250
:class:`dict` instance whose keys are field names and values are
233251
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
@@ -239,6 +257,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
239257
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
240258
If the schema is not given, it will be inferred using the data in the
241259
result set.
260+
- `allow_invalid` (optional): If set to ``True``,
261+
results will have all fields that do not conform to the schema silently converted to NaN.
242262
243263
Additional keyword-arguments passed to this method will be passed
244264
directly to the underlying ``find`` operation.
@@ -255,10 +275,13 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs):
255275
:Returns:
256276
An instance of :class:`dict`.
257277
"""
258-
return _arrow_to_numpy(find_arrow_all(collection, query, schema=schema, **kwargs), schema)
278+
return _arrow_to_numpy(
279+
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs),
280+
schema,
281+
)
259282

260283

261-
def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
284+
def aggregate_numpy_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
262285
"""Method that returns the results of an aggregation pipeline as a
263286
:class:`dict` instance whose keys are field names and values are
264287
:class:`~numpy.ndarray` instances bearing the appropriate dtype.
@@ -270,6 +293,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
270293
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
271294
If the schema is not given, it will be inferred using the data in the
272295
result set.
296+
- `allow_invalid` (optional): If set to ``True``,
297+
results will have all fields that do not conform to the schema silently converted to NaN.
273298
274299
Additional keyword-arguments passed to this method will be passed
275300
directly to the underlying ``aggregate`` operation.
@@ -287,7 +312,10 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
287312
An instance of :class:`dict`.
288313
"""
289314
return _arrow_to_numpy(
290-
aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs), schema
315+
aggregate_arrow_all(
316+
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
317+
),
318+
schema,
291319
)
292320

293321

@@ -326,7 +354,7 @@ def _arrow_to_polars(arrow_table: pa.Table):
326354
return pl.from_arrow(arrow_table_without_extensions)
327355

328356

329-
def find_polars_all(collection, query, *, schema=None, **kwargs):
357+
def find_polars_all(collection, query, *, schema=None, allow_invalid=False, **kwargs):
330358
"""Method that returns the results of a find query as a
331359
:class:`polars.DataFrame` instance.
332360
@@ -337,6 +365,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
337365
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
338366
If the schema is not given, it will be inferred using the data in the
339367
result set.
368+
- `allow_invalid` (optional): If set to ``True``,
369+
results will have all fields that do not conform to the schema silently converted to NaN.
340370
341371
Additional keyword-arguments passed to this method will be passed
342372
directly to the underlying ``find`` operation.
@@ -346,10 +376,12 @@ def find_polars_all(collection, query, *, schema=None, **kwargs):
346376
347377
.. versionadded:: 1.3
348378
"""
349-
return _arrow_to_polars(find_arrow_all(collection, query, schema=schema, **kwargs))
379+
return _arrow_to_polars(
380+
find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs)
381+
)
350382

351383

352-
def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
384+
def aggregate_polars_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs):
353385
"""Method that returns the results of an aggregation pipeline as a
354386
:class:`polars.DataFrame` instance.
355387
@@ -360,14 +392,20 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs):
360392
- `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`.
361393
If the schema is not given, it will be inferred using the data in the
362394
result set.
395+
- `allow_invalid` (optional): If set to ``True``,
396+
results will have all fields that do not conform to the schema silently converted to NaN.
363397
364398
Additional keyword-arguments passed to this method will be passed
365399
directly to the underlying ``aggregate`` operation.
366400
367401
:Returns:
368402
An instance of :class:`polars.DataFrame`.
369403
"""
370-
return _arrow_to_polars(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs))
404+
return _arrow_to_polars(
405+
aggregate_arrow_all(
406+
collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs
407+
)
408+
)
371409

372410

373411
def _transform_bwe(bwe, offset):

bindings/python/pymongoarrow/context.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
class PyMongoArrowContext:
2020
"""A context for converting BSON-formatted data to an Arrow Table."""
2121

22-
def __init__(self, schema, codec_options=None):
22+
def __init__(self, schema, codec_options=None, allow_invalid=False):
2323
"""Initialize the context.
2424
2525
:Parameters:
2626
- `schema`: Instance of :class:`~pymongoarrow.schema.Schema`.
2727
- `builder_map`: Mapping of utf-8-encoded field names to
2828
:class:`~pymongoarrow.builders._BuilderBase` instances.
29+
- `allow_invalid`: If set to ``True``,
30+
results will have all fields that do not conform to the schema silently converted to NaN.
2931
"""
3032
self.schema = schema
3133
if self.schema is None and codec_options is not None:
@@ -40,7 +42,9 @@ def __init__(self, schema, codec_options=None):
4042
# Delayed import to prevent import errors for unbuilt library.
4143
from pymongoarrow.lib import BuilderManager
4244

43-
self.manager = BuilderManager(schema_map, self.schema is not None, self.tzinfo)
45+
self.manager = BuilderManager(
46+
schema_map, self.schema is not None, self.tzinfo, allow_invalid=allow_invalid
47+
)
4448
self.schema_map = schema_map
4549

4650
def process_bson_stream(self, stream):

0 commit comments

Comments
 (0)