Skip to content

Commit a657583

Browse files
sakchalChaluvadi
and
Chaluvadi
authored
Added unit tests for the pad function and fixed pad function return v… (#22)
* Added unit tests for the pad function and fixed pad function return value * fixed import formatting, black and flake8 automatic checks --------- Co-authored-by: Chaluvadi <[email protected]>
1 parent e56731e commit a657583

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

arrayfire_wrapper/lib/create_and_modify_array/create_array/pad.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ def pad(arr: AFArray, begin_shape: tuple[int, ...], end_shape: tuple[int, ...],
2222
end_c_shape.c_array,
2323
border_type.value,
2424
)
25-
return NotImplemented
25+
return out

tests/test_pad.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import random
2+
3+
import numpy as np
4+
import pytest
5+
6+
import arrayfire_wrapper.dtypes as dtypes
7+
import arrayfire_wrapper.lib as wrapper
8+
9+
10+
@pytest.mark.parametrize(
11+
"original_shape",
12+
[
13+
(random.randint(1, 100),),
14+
(random.randint(1, 100), random.randint(1, 100)),
15+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
16+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
17+
],
18+
)
19+
def test_zero_padding(original_shape: tuple) -> None:
20+
"""Test if pad creates an array with no padding if no padding is given"""
21+
original_array = wrapper.constant(2, original_shape, dtypes.s64)
22+
padding = wrapper.Pad(0)
23+
24+
zero_shape = tuple(0 for _ in range(len(original_shape)))
25+
result = wrapper.pad(original_array, zero_shape, zero_shape, padding)
26+
27+
assert wrapper.get_dims(result)[0 : len(original_shape)] == original_shape # noqa: E203
28+
29+
30+
@pytest.mark.parametrize(
31+
"original_shape",
32+
[
33+
(random.randint(1, 100),),
34+
(random.randint(1, 100), random.randint(1, 100)),
35+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
36+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
37+
],
38+
)
39+
def test_negative_padding(original_shape: tuple) -> None:
40+
"""Test if pad can properly handle if negative padding is given"""
41+
with pytest.raises(RuntimeError):
42+
original_array = wrapper.constant(2, original_shape, dtypes.s64)
43+
padding = wrapper.Pad(0)
44+
45+
neg_shape = tuple(-1 for _ in range(len(original_shape)))
46+
result = wrapper.pad(original_array, neg_shape, neg_shape, padding)
47+
48+
assert wrapper.get_dims(result)[0 : len(original_shape)] == original_shape # noqa: E203
49+
50+
51+
@pytest.mark.parametrize(
52+
"original_shape",
53+
[
54+
(random.randint(1, 100),),
55+
(random.randint(1, 100), random.randint(1, 100)),
56+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
57+
(random.randint(1, 100), random.randint(1, 100), random.randint(1, 100), random.randint(1, 100)),
58+
],
59+
)
60+
def test_padding_shape(original_shape: tuple) -> None:
61+
"""Test if pad outputs the correct shape when a padding is adding to the original array"""
62+
original_array = wrapper.constant(2, original_shape, dtypes.s64)
63+
padding = wrapper.Pad(0)
64+
65+
beg_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape)))
66+
end_shape = tuple(random.randint(1, 10) for _ in range(len(original_shape)))
67+
68+
result = wrapper.pad(original_array, beg_shape, end_shape, padding)
69+
new_shape = np.array(beg_shape) + np.array(end_shape) + np.array(original_shape)
70+
71+
assert wrapper.get_dims(result)[0 : len(original_shape)] == tuple(new_shape) # noqa: E203

0 commit comments

Comments
 (0)