Skip to content

Commit 766f00d

Browse files
committed
misc fixes
1 parent 7f84762 commit 766f00d

4 files changed

Lines changed: 19 additions & 14 deletions

File tree

python/pyarrow/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def print_entry(label, value):
234234
FixedSizeBinaryScalar, DictionaryScalar,
235235
MapScalar, StructScalar, UnionScalar,
236236
RunEndEncodedScalar, Bool8Scalar, ExtensionScalar,
237-
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar)
237+
FixedShapeTensorScalar, JsonScalar, OpaqueScalar, UuidScalar,
238+
VariableShapeTensorScalar)
238239

239240
# Buffers, allocation
240241
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,

python/pyarrow/array.pxi

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4974,35 +4974,37 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49744974
]
49754975
]
49764976
"""
4977-
assert isinstance(obj, list), 'obj must be a list of numpy arrays'
4977+
if not isinstance(obj, list) or len(obj) == 0:
4978+
raise TypeError('obj must be a non-empty list of numpy arrays')
49784979
numpy_type = obj[0].dtype
49794980
arrow_type = from_numpy_dtype(numpy_type)
49804981
ndim = obj[0].ndim
49814982
permutations = [(-np.array(o.strides)).argsort(kind="stable") for o in obj]
49824983
permutation = permutations[0]
49834984
shapes = [np.take(o.shape, permutation) for o in obj]
49844985

4985-
if not all([o.dtype == numpy_type for o in obj]):
4986+
if not all(o.dtype == numpy_type for o in obj):
49864987
raise TypeError('All numpy arrays must have matching dtype.')
49874988

4988-
if not all([o.ndim == ndim for o in obj]):
4989+
if not all(o.ndim == ndim for o in obj):
49894990
raise ValueError('All numpy arrays must have matching ndim.')
49904991

4991-
if not all([np.array_equal(p, permutation) for p in permutations]):
4992+
if not all(np.array_equal(p, permutation) for p in permutations):
49924993
raise ValueError('All numpy arrays must have matching permutation.')
49934994

49944995
for shape in shapes:
49954996
if len(shape) < 2:
49964997
raise ValueError(
4997-
"Cannot convert 1D array or scalar to fixed shape tensor array")
4998+
"Cannot convert 1D array or scalar to variable shape tensor array")
49984999
if np.prod(shape) == 0:
49995000
raise ValueError("Expected a non-empty ndarray")
50005001

50015002
values = array([np.ravel(o, order="K") for o in obj], list_(arrow_type))
50025003
shapes = array(shapes, list_(int32(), list_size=ndim))
50035004
struct_arr = StructArray.from_arrays([values, shapes], names=["data", "shape"])
50045005

5005-
return ExtensionArray.from_storage(variable_shape_tensor(arrow_type, ndim, permutation=permutation), struct_arr)
5006+
ext_type = variable_shape_tensor(arrow_type, ndim, permutation=permutation)
5007+
return ExtensionArray.from_storage(ext_type, struct_arr)
50065008

50075009

50085010
cdef dict _array_classes = {

python/pyarrow/tests/test_extension_type.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,16 +1751,16 @@ def test_variable_shape_tensor_array_from_numpy(value_type):
17511751
np.testing.assert_array_equal(result[0].to_numpy_ndarray(), expected)
17521752

17531753
arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=value_type)
1754-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1754+
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to variable"):
17551755
pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17561756

17571757
arr = np.array(1, dtype=value_type)
1758-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1758+
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to variable"):
17591759
pa.VariableShapeTensorArray.from_numpy_ndarray([arr])
17601760

17611761
arr = np.array([], dtype=value_type)
17621762

1763-
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to fixed"):
1763+
with pytest.raises(ValueError, match="Cannot convert 1D array or scalar to variable"):
17641764
pa.VariableShapeTensorArray.from_numpy_ndarray([arr.reshape((0))])
17651765

17661766
with pytest.raises(ValueError, match="Expected a non-empty ndarray"):
@@ -1969,7 +1969,7 @@ def test_variable_shape_tensor_type_is_picklable(pickle_module):
19691969
'fixed_shape_tensor[value_type=int64, shape=[2,2,3], dim_names=[C,H,W]]'
19701970
)
19711971
])
1972-
def test_tensor_type_str(tensor_type, text, pickle_module):
1972+
def test_tensor_type_str(tensor_type, text):
19731973
tensor_type_str = tensor_type.__str__()
19741974
assert text in tensor_type_str
19751975

python/pyarrow/types.pxi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,7 +2075,7 @@ cdef class VariableShapeTensorType(BaseExtensionType):
20752075

20762076
def __reduce__(self):
20772077
return variable_shape_tensor, (self.value_type, self.ndim,
2078-
self.permutation, self.dim_names, self.uniform_shape)
2078+
self.dim_names, self.permutation, self.uniform_shape)
20792079

20802080
def __arrow_ext_scalar_class__(self):
20812081
return VariableShapeTensorScalar
@@ -5936,8 +5936,10 @@ def variable_shape_tensor(DataType value_type, ndim, dim_names=None, permutation
59365936
vector[optional[int64_t]] c_uniform_shape
59375937
shared_ptr[CDataType] c_tensor_ext_type
59385938

5939-
assert value_type is not None
5940-
assert ndim is not None
5939+
if value_type is None:
5940+
raise TypeError('value_type must not be None')
5941+
if ndim is None:
5942+
raise TypeError('ndim must not be None')
59415943

59425944
c_ndim = ndim
59435945

0 commit comments

Comments
 (0)