@@ -4621,7 +4621,7 @@ cdef class FixedShapeTensorArray(ExtensionArray):
46214621 and the rest of the dimensions will match the permuted shape of the fixed
46224622 shape tensor.
46234623
4624- The conversion is zero-copy.
4624+ The conversion is zero-copy if data is primitive numeric and without nulls .
46254625
46264626 Returns
46274627 -------
@@ -4874,6 +4874,137 @@ cdef class Bool8Array(ExtensionArray):
48744874 return Bool8Array.from_storage(storage_arr)
48754875
48764876
4877+ cdef class VariableShapeTensorArray(ExtensionArray):
4878+ """
4879+ Concrete class for variable shape tensor extension arrays.
4880+
4881+ Examples
4882+ --------
4883+ Define the extension type for tensor array
4884+
4885+ >>> import pyarrow as pa
4886+ >>> tensor_type = pa.variable_shape_tensor(pa.float64(), 2)
4887+
4888+ Create an extension array
4889+
4890+ >>> shapes = pa.array([[2, 3], [1, 2]], pa.list_(pa.int32(), 2))
4891+ >>> values = pa.array([[1, 2, 3, 4, 5, 6], [7, 8]], pa.list_(pa.float64()))
4892+ >>> arr = pa.StructArray.from_arrays([values, shapes], names=["data", "shape"])
4893+ >>> pa.ExtensionArray.from_storage(tensor_type, arr)
4894+ <pyarrow.lib.VariableShapeTensorArray object at ...>
4895+ -- is_valid: all not null
4896+ -- child 0 type: list<item: double>
4897+ [
4898+ [
4899+ 1,
4900+ 2,
4901+ 3,
4902+ 4,
4903+ 5,
4904+ 6
4905+ ],
4906+ [
4907+ 7,
4908+ 8
4909+ ]
4910+ ]
4911+ -- child 1 type: fixed_size_list<item: int32>[2]
4912+ [
4913+ [
4914+ 2,
4915+ 3
4916+ ],
4917+ [
4918+ 1,
4919+ 2
4920+ ]
4921+ ]
4922+ """
4923+
4924+ @staticmethod
4925+ def from_numpy_ndarray (obj ):
4926+ """
4927+ Convert a list of numpy.ndarrays to a variable shape tensor extension array.
4928+ The length of the input list will become the length of the variable shape tensor array.
4929+
4930+ Parameters
4931+ ----------
4932+ obj : list of numpy.ndarray
4933+
4934+ Examples
4935+ --------
4936+ >>> import pyarrow as pa
4937+ >>> import numpy as np
4938+
4939+ >>> ndarray_list = [
4940+ ... np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
4941+ ... np.array([[7, 8]], dtype=np.float32),
4942+ ... ]
4943+ >>> arr = pa.VariableShapeTensorArray.from_numpy_ndarray(ndarray_list)
4944+ >>> assert len(ndarray_list) == len(arr)
4945+ >>> arr.type
4946+ VariableShapeTensorType(extension<arrow.variable_shape_tensor[value_type=float, ndim=2, permutation=[0,1]]>)
4947+ >>> arr
4948+ <pyarrow.lib.VariableShapeTensorArray object at ...>
4949+ -- is_valid: all not null
4950+ -- child 0 type: list<item: float>
4951+ [
4952+ [
4953+ 1,
4954+ 2,
4955+ 3,
4956+ 4,
4957+ 5,
4958+ 6
4959+ ],
4960+ [
4961+ 7,
4962+ 8
4963+ ]
4964+ ]
4965+ -- child 1 type: fixed_size_list<item: int32>[2]
4966+ [
4967+ [
4968+ 2,
4969+ 3
4970+ ],
4971+ [
4972+ 1,
4973+ 2
4974+ ]
4975+ ]
4976+ """
4977+ assert isinstance (obj, list ), ' obj must be a list of numpy arrays'
4978+ numpy_type = obj[0 ].dtype
4979+ arrow_type = from_numpy_dtype(numpy_type)
4980+ ndim = obj[0 ].ndim
4981+ permutations = [(- np.array(o.strides)).argsort(kind = " stable" ) for o in obj]
4982+ permutation = permutations[0 ]
4983+ shapes = [np.take(o.shape, permutation) for o in obj]
4984+
4985+ if not all ([o.dtype == numpy_type for o in obj]):
4986+ raise TypeError (' All numpy arrays must have matching dtype.' )
4987+
4988+ if not all ([o.ndim == ndim for o in obj]):
4989+ raise ValueError (' All numpy arrays must have matching ndim.' )
4990+
4991+ if not all ([np.array_equal(p, permutation) for p in permutations]):
4992+ raise ValueError (' All numpy arrays must have matching permutation.' )
4993+
4994+ for shape in shapes:
4995+ if len (shape) < 2 :
4996+ raise ValueError (
4997+ " Cannot convert 1D array or scalar to fixed shape tensor array" )
4998+ if np.prod(shape) == 0 :
4999+ raise ValueError (" Expected a non-empty ndarray" )
5000+
5001+ values = array([np.ravel(o, order = " K" ) for o in obj], list_(arrow_type))
5002+ shapes = array(shapes, list_(int32(), list_size = ndim))
5003+ struct_arr = StructArray.from_arrays([values, shapes], names = [" data" , " shape" ])
5004+
5005+ return ExtensionArray.from_storage(variable_shape_tensor(arrow_type, ndim, permutation = permutation), struct_arr)
5006+
5007+
48775008cdef dict _array_classes = {
48785009 _Type_NA: NullArray,
48795010 _Type_BOOL: BooleanArray,
0 commit comments