@@ -355,7 +355,7 @@ def test_sign_shapes_invalid(invdtypes: dtype.Dtype) -> None:
355355)
356356@pytest .mark .parametrize ("dtype_name" , util .get_real_types ())
357357def test_trunc_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
358- """Test truncating operation between two arrays of the same shape"""
358+ """Test truncating operation for an array with varying shape"""
359359 util .check_type_supported (dtype_name )
360360 out = wrapper .randu (shape , dtype_name )
361361
@@ -366,7 +366,7 @@ def test_trunc_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
366366
367367@pytest .mark .parametrize ("invdtypes" , util .get_complex_types ())
368368def test_trunc_shapes_invalid (invdtypes : dtype .Dtype ) -> None :
369- """Test trunc operation between two arrays of the same shape"""
369+ """Test trunc operation for an array with varrying shape and invalid dtypes """
370370 with pytest .raises (RuntimeError ):
371371 shape = (3 , 3 )
372372 out = wrapper .randu (shape , invdtypes )
@@ -408,4 +408,26 @@ def test_hypot_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
408408 shape = (5 , 5 )
409409 lhs = wrapper .randu (shape , invdtypes )
410410 rhs = wrapper .randu (shape , invdtypes )
411- wrapper .hypot (rhs , lhs , True )
411+ wrapper .hypot (rhs , lhs , True )
412+ @pytest .mark .parametrize (
413+ "shape" ,
414+ [
415+ (),
416+ (random .randint (1 , 10 ),),
417+ (random .randint (1 , 10 ), random .randint (1 , 10 )),
418+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
419+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
420+ ],
421+ )
422+ @pytest .mark .parametrize ("dtype_name" , util .get_real_types ())
423+ def test_clamp_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
424+ """Test clamp operation between two arrays of the same shape"""
425+ util .check_type_supported (dtype_name )
426+ og = wrapper .randu (shape , dtype_name )
427+ low = wrapper .randu (shape , dtype_name )
428+ high = wrapper .randu (shape , dtype_name )
429+ # talked to stefan about this, testing broadcasting is unnecessary
430+ result = wrapper .clamp (og , low , high , False )
431+ assert (
432+ wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
433+ ), f"failed for shape: { shape } and dtype { dtype_name } "
0 commit comments