Skip to content

Commit 7f84762

Browse files
committed
GH-38007: [Python] Add VariableShapeTensor Python bindings
Add PyArrow bindings for the VariableShapeTensor extension type, including VariableShapeTensorType, VariableShapeTensorArray, and VariableShapeTensorScalar with support for converting to/from NumPy tensors.
1 parent 2fcc3ec commit 7f84762

9 files changed

Lines changed: 671 additions & 7 deletions

File tree

docs/source/python/api/arrays.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ may expose data type-specific methods or properties.
101101
JsonArray
102102
UuidArray
103103
Bool8Array
104+
VariableShapeTensorArray
104105

105106
.. _api.scalar:
106107

python/pyarrow/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def print_entry(label, value):
165165
dictionary,
166166
run_end_encoded,
167167
bool8, fixed_shape_tensor, json_, opaque, uuid,
168+
variable_shape_tensor,
168169
field,
169170
type_for_alias,
170171
DataType, DictionaryType, StructType,
@@ -178,6 +179,7 @@ def print_entry(label, value):
178179
RunEndEncodedType, Bool8Type, FixedShapeTensorType,
179180
JsonType, OpaqueType, UuidType,
180181
UnknownExtensionType,
182+
VariableShapeTensorType,
181183
register_extension_type, unregister_extension_type,
182184
DictionaryMemo,
183185
KeyValueMetadata,
@@ -214,6 +216,7 @@ def print_entry(label, value):
214216
StructArray, ExtensionArray,
215217
RunEndEncodedArray, Bool8Array, FixedShapeTensorArray,
216218
JsonArray, OpaqueArray, UuidArray,
219+
VariableShapeTensorArray,
217220
scalar, NA, _NULL as NULL, Scalar,
218221
NullScalar, BooleanScalar,
219222
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,

python/pyarrow/array.pxi

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4621,7 +4621,7 @@ cdef class FixedShapeTensorArray(ExtensionArray):
46214621
and the rest of the dimensions will match the permuted shape of the fixed
46224622
shape tensor.
46234623
4624-
The conversion is zero-copy.
4624+
The conversion is zero-copy if data is primitive numeric and without nulls.
46254625
46264626
Returns
46274627
-------
@@ -4874,6 +4874,137 @@ cdef class Bool8Array(ExtensionArray):
48744874
return Bool8Array.from_storage(storage_arr)
48754875

48764876

4877+
cdef class VariableShapeTensorArray(ExtensionArray):
    """
    Concrete class for variable shape tensor extension arrays.

    Examples
    --------
    Define the extension type for tensor array

    >>> import pyarrow as pa
    >>> tensor_type = pa.variable_shape_tensor(pa.float64(), 2)

    Create an extension array

    >>> shapes = pa.array([[2, 3], [1, 2]], pa.list_(pa.int32(), 2))
    >>> values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], pa.list_(pa.float64()))
    >>> arr = pa.StructArray.from_arrays([values, shapes], names=["data", "shape"])
    >>> pa.ExtensionArray.from_storage(tensor_type, arr)
    <pyarrow.lib.VariableShapeTensorArray object at ...>
    -- is_valid: all not null
    -- child 0 type: list<item: double>
      [
        [
          1,
          2,
          3,
          4,
          5,
          6
        ],
        [
          7,
          8
        ]
      ]
    -- child 1 type: fixed_size_list<item: int32>[2]
      [
        [
          2,
          3
        ],
        [
          1,
          2
        ]
      ]
    """

    @staticmethod
    def from_numpy_ndarray(obj):
        """
        Convert a list of numpy.ndarrays to a variable shape tensor extension array.
        The length of the input list will become the length of the variable shape tensor array.

        All arrays must share the same dtype, number of dimensions and
        memory layout (stride ordering); only their shapes may differ.

        Parameters
        ----------
        obj : list of numpy.ndarray

        Returns
        -------
        VariableShapeTensorArray

        Raises
        ------
        TypeError
            If ``obj`` is not a list, or the arrays do not share one dtype.
        ValueError
            If ``obj`` is empty, the arrays differ in ndim or permutation,
            the arrays have fewer than two dimensions, or any array is empty.

        Examples
        --------
        >>> import pyarrow as pa
        >>> import numpy as np

        >>> ndarray_list = [
        ...     np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
        ...     np.array([[7, 8]], dtype=np.float32),
        ... ]
        >>> arr = pa.VariableShapeTensorArray.from_numpy_ndarray(ndarray_list)
        >>> assert len(ndarray_list) == len(arr)
        >>> arr.type
        VariableShapeTensorType(extension<arrow.variable_shape_tensor[value_type=float, ndim=2, permutation=[0,1]]>)
        >>> arr
        <pyarrow.lib.VariableShapeTensorArray object at ...>
        -- is_valid: all not null
        -- child 0 type: list<item: float>
          [
            [
              1,
              2,
              3,
              4,
              5,
              6
            ],
            [
              7,
              8
            ]
          ]
        -- child 1 type: fixed_size_list<item: int32>[2]
          [
            [
              2,
              3
            ],
            [
              1,
              2
            ]
          ]
        """
        # Raise explicitly instead of using ``assert``: assertions can be
        # compiled out, which would let invalid input through silently.
        if not isinstance(obj, list):
            raise TypeError('obj must be a list of numpy arrays')
        # The dtype and ndim of the result are inferred from the first
        # element, so an empty list cannot be converted.
        if not obj:
            raise ValueError(
                'Cannot create a variable shape tensor array from an empty list')

        numpy_type = obj[0].dtype
        arrow_type = from_numpy_dtype(numpy_type)
        ndim = obj[0].ndim

        if not all(o.dtype == numpy_type for o in obj):
            raise TypeError('All numpy arrays must have matching dtype.')

        if not all(o.ndim == ndim for o in obj):
            raise ValueError('All numpy arrays must have matching ndim.')

        # Dimension order by decreasing stride, i.e. the physical layout of
        # each array. A stable sort keeps equal strides in logical order.
        permutations = [(-np.array(o.strides)).argsort(kind="stable") for o in obj]
        permutation = permutations[0]

        if not all(np.array_equal(p, permutation) for p in permutations):
            raise ValueError('All numpy arrays must have matching permutation.')

        # ndim is now known to be identical for every element, so a single
        # check covers the whole list.
        if ndim < 2:
            raise ValueError(
                "Cannot convert 1D array or scalar to variable shape tensor array")

        # Shapes are only permuted after ndim has been validated; indexing a
        # shorter shape with the first array's permutation would otherwise
        # fail with an unrelated IndexError.
        shapes = [np.take(o.shape, permutation) for o in obj]
        for shape in shapes:
            if np.prod(shape) == 0:
                raise ValueError("Expected a non-empty ndarray")

        # order="K" flattens each array in memory order, matching the
        # permuted shapes stored alongside the data.
        values = array([np.ravel(o, order="K") for o in obj], list_(arrow_type))
        shapes = array(shapes, list_(int32(), list_size=ndim))
        struct_arr = StructArray.from_arrays([values, shapes], names=["data", "shape"])

        return ExtensionArray.from_storage(
            variable_shape_tensor(arrow_type, ndim, permutation=permutation),
            struct_arr)
5006+
5007+
48775008
cdef dict _array_classes = {
48785009
_Type_NA: NullArray,
48795010
_Type_BOOL: BooleanArray,

python/pyarrow/includes/libarrow.pxd

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,14 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
908908
const shared_ptr[CBuffer] null_bitmap,
909909
)
910910

911+
@staticmethod
912+
CResult[shared_ptr[CArray]] FromArraysAndType" FromArrays"(
913+
shared_ptr[CDataType],
914+
const shared_ptr[CArray]& offsets,
915+
const shared_ptr[CArray]& keys,
916+
const shared_ptr[CArray]& items,
917+
CMemoryPool* pool)
918+
911919
shared_ptr[CArray] keys()
912920
shared_ptr[CArray] items()
913921
CMapType* map_type()
@@ -1184,6 +1192,11 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
11841192
void set_chunksize(int64_t chunksize)
11851193

11861194
cdef cppclass CTensor" arrow::Tensor":
1195+
CTensor(const shared_ptr[CDataType]& type,
1196+
const shared_ptr[CBuffer]& data,
1197+
const vector[int64_t]& shape,
1198+
const vector[int64_t]& strides,
1199+
const vector[c_string]& dim_names)
11871200
shared_ptr[CDataType] type()
11881201
shared_ptr[CBuffer] data()
11891202

@@ -3014,6 +3027,24 @@ cdef extern from "arrow/extension_type.h" namespace "arrow":
30143027

30153028
shared_ptr[CArray] storage()
30163029

3030+
# Cython declarations for the C++ class arrow::extension::VariableShapeTensorType
# declared in arrow/extension/variable_shape_tensor.h. The block is marked
# ``nogil`` so the declared functions may be called without holding the GIL.
cdef extern from "arrow/extension/variable_shape_tensor.h" namespace "arrow::extension" nogil:
    cdef cppclass CVariableShapeTensorType \
            " arrow::extension::VariableShapeTensorType"(CExtensionType):

        # Build a Tensor from a single extension scalar; the Result carries
        # either the tensor or an error status.
        CResult[shared_ptr[CTensor]] MakeTensor(const shared_ptr[CExtensionScalar]& scalar) const

        # Static factory returning the extension data type (or an error).
        # NOTE(review): the vector arguments are declared by value here;
        # confirm against the C++ header whether they are const references.
        @staticmethod
        CResult[shared_ptr[CDataType]] Make(const shared_ptr[CDataType]& value_type,
                                            const int32_t ndim,
                                            const vector[int64_t] permutation,
                                            const vector[c_string] dim_names,
                                            const vector[optional[int64_t]] uniform_shape)

        # Accessors named after the corresponding Make() parameters.
        const shared_ptr[CDataType] value_type()
        const int32_t ndim()
        const vector[int64_t] permutation()
        const vector[c_string] dim_names()
        # presumably an empty optional marks a dimension whose size is not
        # uniform across tensors; verify against the extension type spec
        const vector[optional[int64_t]] uniform_shape()
30173048

30183049
cdef extern from "arrow/extension/json.h" namespace "arrow::extension" nogil:
30193050
cdef cppclass CJsonType" arrow::extension::JsonExtensionType"(CExtensionType):
@@ -3034,7 +3065,7 @@ cdef extern from "arrow/extension/uuid.h" namespace "arrow::extension" nogil:
30343065

30353066
cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension" nogil:
30363067
cdef cppclass CFixedShapeTensorType \
3037-
" arrow::extension::FixedShapeTensorType"(CExtensionType):
3068+
" arrow::extension::FixedShapeTensorType"(CExtensionType) nogil:
30383069

30393070
CResult[shared_ptr[CTensor]] MakeTensor(const shared_ptr[CExtensionScalar]& scalar) const
30403071

python/pyarrow/lib.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ cdef class ExtensionType(BaseExtensionType):
195195
const CPyExtensionType* cpy_ext_type
196196

197197

198+
# Declaration of the Python-level wrapper class for the variable shape tensor
# extension type.
cdef class VariableShapeTensorType(BaseExtensionType):
    cdef:
        # Const pointer to the wrapped C++ VariableShapeTensorType.
        # NOTE(review): presumably assigned when the data type is wrapped;
        # confirm where it is initialised (not visible in this hunk).
        const CVariableShapeTensorType* tensor_ext_type
201+
202+
198203
cdef class FixedShapeTensorType(BaseExtensionType):
199204
cdef:
200205
const CFixedShapeTensorType* tensor_ext_type

python/pyarrow/public-api.pxi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ cdef api object pyarrow_wrap_data_type(
131131
out = Bool8Type.__new__(Bool8Type)
132132
elif extension_name == b"arrow.fixed_shape_tensor":
133133
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
134+
elif extension_name == b"arrow.variable_shape_tensor":
135+
out = VariableShapeTensorType.__new__(VariableShapeTensorType)
134136
elif extension_name == b"arrow.opaque":
135137
out = OpaqueType.__new__(OpaqueType)
136138
elif extension_name == b"arrow.uuid":

python/pyarrow/scalar.pxi

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1484,7 +1484,7 @@ cdef class FixedShapeTensorScalar(ExtensionScalar):
14841484
14851485
The resulting ndarray's shape matches the permuted shape of the
14861486
fixed shape tensor scalar.
1487-
The conversion is zero-copy.
1487+
The conversion is zero-copy if data is primitive numeric and without nulls.
14881488
14891489
Returns
14901490
-------
@@ -1539,6 +1539,43 @@ cdef class Bool8Scalar(ExtensionScalar):
15391539
py_val = super().as_py()
15401540
return None if py_val is None else py_val != 0
15411541

1542+
1543+
cdef class VariableShapeTensorScalar(ExtensionScalar):
    """
    Scalar class for values of the variable shape tensor extension type.
    """

    def to_numpy_ndarray(self):
        """
        Return this variable shape tensor scalar as a numpy array.

        The scalar is first converted to a pyarrow.Tensor, which is then
        converted to numpy. The conversion is zero-copy if data is
        primitive numeric and without nulls.

        Returns
        -------
        numpy.ndarray
        """
        tensor = self.to_tensor()
        return tensor.to_numpy()

    def to_tensor(self):
        """
        Return this variable shape tensor scalar as a pyarrow.Tensor.

        Returns
        -------
        tensor : pyarrow.Tensor
        """
        cdef:
            shared_ptr[CVariableShapeTensorType] tensor_type
            shared_ptr[CExtensionScalar] ext_scalar
            shared_ptr[CTensor] c_tensor

        # View the wrapped scalar and its type as their concrete
        # extension classes.
        tensor_type = static_pointer_cast[CVariableShapeTensorType, CDataType](
            self.wrapped.get().type)
        ext_scalar = static_pointer_cast[CExtensionScalar, CScalar](self.wrapped)

        # The C++ conversion does not touch Python objects, so the GIL is
        # released while it runs.
        with nogil:
            c_tensor = GetResultValue(tensor_type.get().MakeTensor(ext_scalar))
        return pyarrow_wrap_tensor(c_tensor)
1577+
1578+
15421579
cdef dict _scalar_classes = {
15431580
_Type_BOOL: BooleanScalar,
15441581
_Type_UINT8: UInt8Scalar,

0 commit comments

Comments
 (0)