@@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do
31
31
assert NxSignal.Filters . median ( t , opts ) == expected
32
32
end
33
33
34
+ test "performs n-dim median filter" do
35
+ t =
36
+ Nx . tensor ( [
37
+ [
38
+ [ 31 , 11 , 17 , 13 , 1 ] ,
39
+ [ 1 , 3 , 19 , 23 , 29 ] ,
40
+ [ 19 , 5 , 7 , 37 , 2 ]
41
+ ] ,
42
+ [
43
+ [ 19 , 5 , 7 , 37 , 2 ] ,
44
+ [ 1 , 3 , 19 , 23 , 29 ] ,
45
+ [ 31 , 11 , 17 , 13 , 1 ]
46
+ ] ,
47
+ [
48
+ [ 1 , 3 , 19 , 23 , 29 ] ,
49
+ [ 31 , 11 , 17 , 13 , 1 ] ,
50
+ [ 19 , 5 , 7 , 37 , 2 ]
51
+ ]
52
+ ] )
53
+
54
+ k1 = { 3 , 3 , 1 }
55
+ k2 = { 3 , 3 , 3 }
56
+
57
+ expected1 =
58
+ Nx . tensor ( [
59
+ [
60
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
61
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
62
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
63
+ ] ,
64
+ [
65
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
66
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
67
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
68
+ ] ,
69
+ [
70
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
71
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ] ,
72
+ [ 19.0 , 5.0 , 17.0 , 23.0 , 2.0 ]
73
+ ]
74
+ ] )
75
+
76
+ expected2 =
77
+ Nx . tensor ( [
78
+ [
79
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
80
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
81
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
82
+ ] ,
83
+ [
84
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
85
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
86
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
87
+ ] ,
88
+ [
89
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
90
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ] ,
91
+ [ 11.0 , 13.0 , 17.0 , 17.0 , 17.0 ]
92
+ ]
93
+ ] )
94
+
95
+ assert NxSignal.Filters . median ( t , kernel_shape: k1 ) == expected1
96
+ assert NxSignal.Filters . median ( t , kernel_shape: k2 ) == expected2
97
+ end
98
+
34
99
test "raises if kernel_shape is not compatible" do
35
100
t1 = Nx . iota ( { 10 } )
36
101
opts1 = [ kernel_shape: { 5 , 5 } ]
@@ -50,25 +115,5 @@ defmodule NxSignal.FiltersTest do
50
115
fn -> NxSignal.Filters . median ( t2 , opts2 ) end
51
116
)
52
117
end
53
-
54
- test "raises if tensor rank is not 1 or 2" do
55
- t1 = Nx . tensor ( 1 )
56
- opts1 = [ kernel_shape: { 1 } ]
57
-
58
- assert_raise (
59
- ArgumentError ,
60
- "tensor must be of rank 1 or 2" ,
61
- fn -> NxSignal.Filters . median ( t1 , opts1 ) end
62
- )
63
-
64
- t2 = Nx . iota ( { 5 , 5 , 5 } )
65
- opts2 = [ kernel_shape: { 3 , 3 , 3 } ]
66
-
67
- assert_raise (
68
- ArgumentError ,
69
- "tensor must be of rank 1 or 2" ,
70
- fn -> NxSignal.Filters . median ( t2 , opts2 ) end
71
- )
72
- end
73
118
end
74
119
end
0 commit comments