Skip to content

Commit 30cc5dd

Browse files
thisisnicAlenkaFpitrou
authored
GH-45653: [Python] Scalar subclasses should implement Python protocols (#45818)
### Rationale for this change

Implement dunder methods on Scalar objects.

### What changes are included in this PR?

* integer scalars implement `__int__`
* floating-point scalars implement `__float__`
* binary scalars implement [`__bytes__`](https://docs.python.org/3.13/reference/datamodel.html#object.__bytes__)
* binary scalars implement the [buffer protocol](https://docs.python.org/3.13/reference/datamodel.html#object.__buffer__)
* we explicitly test that Struct scalars implement Sequences
* Map scalars implement the Mapping protocol

### Are these changes tested?

Yes

### Are there any user-facing changes?

Yes

* GitHub Issue: #45653

Lead-authored-by: Nic Crane <thisisnic@gmail.com>
Co-authored-by: Alenka Frim <AlenkaF@users.noreply.github.com>
Co-authored-by: Antoine Pitrou <pitrou@free.fr>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 5425175 commit 30cc5dd

3 files changed

Lines changed: 144 additions & 9 deletions

File tree

docs/source/python/compute.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ Below are a few simple examples::
6363
>>> pc.multiply(x, y)
6464
<pyarrow.DoubleScalar: 72.54>
6565

66+
If you are using a compute function which returns more than one value, results
67+
will be returned as a ``StructScalar``. You can extract the individual values by
68+
calling the :meth:`pyarrow.StructScalar.values` method::
69+
70+
>>> import pyarrow as pa
71+
>>> import pyarrow.compute as pc
72+
>>> a = pa.array([1, 1, 2, 3])
73+
>>> pc.min_max(a)
74+
<pyarrow.StructScalar: [('min', 1), ('max', 3)]>
75+
>>> a, b = pc.min_max(a).values()
76+
>>> a
77+
<pyarrow.Int64Scalar: 1>
78+
>>> b
79+
<pyarrow.Int64Scalar: 3>
80+
6681
These functions can do more than just element-by-element operations.
6782
Here is an example of sorting a table::
6883

python/pyarrow/scalar.pxi

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import collections
1919
import warnings
2020
from uuid import UUID
21+
from collections.abc import Sequence, Mapping
2122

2223

2324
cdef class Scalar(_Weakrefable):
@@ -219,6 +220,8 @@ cdef class BooleanScalar(Scalar):
219220
cdef CBooleanScalar* sp = <CBooleanScalar*> self.wrapped.get()
220221
return sp.value if sp.is_valid else None
221222

223+
def __bool__(self):
224+
return self.as_py() or False
222225

223226
cdef class UInt8Scalar(Scalar):
224227
"""
@@ -238,6 +241,9 @@ cdef class UInt8Scalar(Scalar):
238241
cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.wrapped.get()
239242
return sp.value if sp.is_valid else None
240243

244+
def __index__(self):
245+
return self.as_py()
246+
241247

242248
cdef class Int8Scalar(Scalar):
243249
"""
@@ -257,6 +263,9 @@ cdef class Int8Scalar(Scalar):
257263
cdef CInt8Scalar* sp = <CInt8Scalar*> self.wrapped.get()
258264
return sp.value if sp.is_valid else None
259265

266+
def __index__(self):
267+
return self.as_py()
268+
260269

261270
cdef class UInt16Scalar(Scalar):
262271
"""
@@ -276,6 +285,9 @@ cdef class UInt16Scalar(Scalar):
276285
cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.wrapped.get()
277286
return sp.value if sp.is_valid else None
278287

288+
def __index__(self):
289+
return self.as_py()
290+
279291

280292
cdef class Int16Scalar(Scalar):
281293
"""
@@ -295,6 +307,9 @@ cdef class Int16Scalar(Scalar):
295307
cdef CInt16Scalar* sp = <CInt16Scalar*> self.wrapped.get()
296308
return sp.value if sp.is_valid else None
297309

310+
def __index__(self):
311+
return self.as_py()
312+
298313

299314
cdef class UInt32Scalar(Scalar):
300315
"""
@@ -314,6 +329,9 @@ cdef class UInt32Scalar(Scalar):
314329
cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.wrapped.get()
315330
return sp.value if sp.is_valid else None
316331

332+
def __index__(self):
333+
return self.as_py()
334+
317335

318336
cdef class Int32Scalar(Scalar):
319337
"""
@@ -333,6 +351,9 @@ cdef class Int32Scalar(Scalar):
333351
cdef CInt32Scalar* sp = <CInt32Scalar*> self.wrapped.get()
334352
return sp.value if sp.is_valid else None
335353

354+
def __index__(self):
355+
return self.as_py()
356+
336357

337358
cdef class UInt64Scalar(Scalar):
338359
"""
@@ -352,6 +373,9 @@ cdef class UInt64Scalar(Scalar):
352373
cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.wrapped.get()
353374
return sp.value if sp.is_valid else None
354375

376+
def __index__(self):
377+
return self.as_py()
378+
355379

356380
cdef class Int64Scalar(Scalar):
357381
"""
@@ -371,6 +395,9 @@ cdef class Int64Scalar(Scalar):
371395
cdef CInt64Scalar* sp = <CInt64Scalar*> self.wrapped.get()
372396
return sp.value if sp.is_valid else None
373397

398+
def __index__(self):
399+
return self.as_py()
400+
374401

375402
cdef class HalfFloatScalar(Scalar):
376403
"""
@@ -390,6 +417,12 @@ cdef class HalfFloatScalar(Scalar):
390417
cdef CHalfFloatScalar* sp = <CHalfFloatScalar*> self.wrapped.get()
391418
return PyFloat_FromHalf(sp.value) if sp.is_valid else None
392419

420+
def __float__(self):
421+
return self.as_py()
422+
423+
def __int__(self):
424+
return int(self.as_py())
425+
393426

394427
cdef class FloatScalar(Scalar):
395428
"""
@@ -409,6 +442,12 @@ cdef class FloatScalar(Scalar):
409442
cdef CFloatScalar* sp = <CFloatScalar*> self.wrapped.get()
410443
return sp.value if sp.is_valid else None
411444

445+
def __float__(self):
446+
return self.as_py()
447+
448+
def __int__(self):
449+
return int(float(self))
450+
412451

413452
cdef class DoubleScalar(Scalar):
414453
"""
@@ -428,6 +467,12 @@ cdef class DoubleScalar(Scalar):
428467
cdef CDoubleScalar* sp = <CDoubleScalar*> self.wrapped.get()
429468
return sp.value if sp.is_valid else None
430469

470+
def __float__(self):
471+
return self.as_py()
472+
473+
def __int__(self):
474+
return int(float(self))
475+
431476

432477
cdef class Decimal32Scalar(Scalar):
433478
"""
@@ -843,6 +888,15 @@ cdef class BinaryScalar(Scalar):
843888
buffer = self.as_buffer()
844889
return None if buffer is None else buffer.to_pybytes()
845890

891+
def __bytes__(self):
892+
return self.as_py()
893+
894+
def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
895+
buf = self.as_buffer()
896+
if buf is None:
897+
raise ValueError("Cannot export buffer from null Arrow Scalar")
898+
cp.PyObject_GetBuffer(buf, buffer, flags)
899+
846900

847901
cdef class LargeBinaryScalar(BinaryScalar):
848902
pass
@@ -883,7 +937,7 @@ cdef class StringViewScalar(StringScalar):
883937
pass
884938

885939

886-
cdef class ListScalar(Scalar):
940+
cdef class ListScalar(Scalar, Sequence):
887941
"""
888942
Concrete class for list-like scalars.
889943
"""
@@ -952,7 +1006,7 @@ cdef class LargeListViewScalar(ListScalar):
9521006
pass
9531007

9541008

955-
cdef class StructScalar(Scalar, collections.abc.Mapping):
1009+
cdef class StructScalar(Scalar, Mapping):
9561010
"""
9571011
Concrete class for struct scalars.
9581012
"""
@@ -1051,20 +1105,34 @@ cdef class StructScalar(Scalar, collections.abc.Mapping):
10511105
return str(self._as_py_tuple())
10521106

10531107

1054-
cdef class MapScalar(ListScalar):
1108+
cdef class MapScalar(ListScalar, Mapping):
10551109
"""
10561110
Concrete class for map scalars.
10571111
"""
10581112

10591113
def __getitem__(self, i):
10601114
"""
1061-
Return the value at the given index.
1115+
Return the value at the given index or key.
10621116
"""
1117+
10631118
arr = self.values
10641119
if arr is None:
1065-
raise IndexError(i)
1120+
raise IndexError(i) if isinstance(i, int) else KeyError(i)
1121+
1122+
key_field = self.type.key_field.name
1123+
item_field = self.type.item_field.name
1124+
1125+
if isinstance(i, (bytes, str)):
1126+
try:
1127+
key_index = list(self.keys()).index(i)
1128+
except ValueError:
1129+
raise KeyError(i)
1130+
1131+
dct = arr[_normalize_index(key_index, len(arr))]
1132+
return dct[item_field]
1133+
10661134
dct = arr[_normalize_index(i, len(arr))]
1067-
return (dct[self.type.key_field.name], dct[self.type.item_field.name])
1135+
return (dct[key_field], dct[item_field])
10681136

10691137
def __iter__(self):
10701138
"""
@@ -1118,6 +1186,16 @@ cdef class MapScalar(ListScalar):
11181186
result_dict[key] = value
11191187
return result_dict
11201188

1189+
def keys(self):
1190+
"""
1191+
Return the keys of the map as a list.
1192+
"""
1193+
arr = self.values
1194+
if arr is None:
1195+
return []
1196+
key_field = self.type.key_field.name
1197+
return [k.as_py() for k in arr.field(key_field)]
1198+
11211199

11221200
cdef class DictionaryScalar(Scalar):
11231201
"""

0 commit comments

Comments
 (0)