@@ -4974,35 +4974,37 @@ cdef class VariableShapeTensorArray(ExtensionArray):
49744974 ]
49754975 ]
49764976 """
4977- assert isinstance (obj, list ), ' obj must be a list of numpy arrays'
4977+ if not isinstance (obj, list ) or len (obj) == 0 :
4978+ raise TypeError (' obj must be a non-empty list of numpy arrays' )
49784979 numpy_type = obj[0 ].dtype
49794980 arrow_type = from_numpy_dtype(numpy_type)
49804981 ndim = obj[0 ].ndim
49814982 permutations = [(- np.array(o.strides)).argsort(kind = " stable" ) for o in obj]
49824983 permutation = permutations[0 ]
49834984 shapes = [np.take(o.shape, permutation) for o in obj]
49844985
4985- if not all ([ o.dtype == numpy_type for o in obj] ):
4986+ if not all (o.dtype == numpy_type for o in obj):
49864987 raise TypeError (' All numpy arrays must have matching dtype.' )
49874988
4988- if not all ([ o.ndim == ndim for o in obj] ):
4989+ if not all (o.ndim == ndim for o in obj):
49894990 raise ValueError (' All numpy arrays must have matching ndim.' )
49904991
4991- if not all ([ np.array_equal(p, permutation) for p in permutations] ):
4992+ if not all (np.array_equal(p, permutation) for p in permutations):
49924993 raise ValueError (' All numpy arrays must have matching permutation.' )
49934994
49944995 for shape in shapes:
49954996 if len (shape) < 2 :
49964997 raise ValueError (
4997- " Cannot convert 1D array or scalar to fixed shape tensor array" )
4998+ " Cannot convert 1D array or scalar to variable shape tensor array" )
49984999 if np.prod(shape) == 0 :
49995000 raise ValueError (" Expected a non-empty ndarray" )
50005001
50015002 values = array([np.ravel(o, order = " K" ) for o in obj], list_(arrow_type))
50025003 shapes = array(shapes, list_(int32(), list_size = ndim))
50035004 struct_arr = StructArray.from_arrays([values, shapes], names = [" data" , " shape" ])
50045005
5005- return ExtensionArray.from_storage(variable_shape_tensor(arrow_type, ndim, permutation = permutation), struct_arr)
5006+ ext_type = variable_shape_tensor(arrow_type, ndim, permutation = permutation)
5007+ return ExtensionArray.from_storage(ext_type, struct_arr)
50065008
50075009
50085010cdef dict _array_classes = {
0 commit comments