Skip to content

Commit 3088fd4

Browse files
Generalize median filter to N dims (#21)
* generalize median filter to N dims * run mix format * add n-dim test
1 parent d0b7df4 commit 3088fd4

File tree

2 files changed

+86
-41
lines changed

2 files changed

+86
-41
lines changed

lib/nx_signal/filters.ex

+21-21
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,24 @@ defmodule NxSignal.Filters do
55
import Nx.Defn
66

77
@doc ~S"""
8-
Performs a median filter on a rank 1 or rank 2 tensor.
8+
Performs a median filter on a tensor.
99
1010
## Options
1111
1212
* `:kernel_shape` - the shape of the sliding window.
1313
It must be compatible with the shape of the tensor.
1414
"""
1515
@doc type: :filters
16-
deftransform median(t = %Nx.Tensor{shape: {length}}, opts) do
16+
defn median(t, opts) do
1717
validate_median_opts!(t, opts)
18-
{kernel_length} = opts[:kernel_shape]
19-
20-
median(Nx.reshape(t, {1, length}), kernel_shape: {1, kernel_length})
21-
|> Nx.squeeze()
22-
end
23-
24-
deftransform median(t = %Nx.Tensor{shape: {_h, _w}}, opts) do
25-
validate_median_opts!(t, opts)
26-
median_n(t, opts)
27-
end
28-
29-
deftransform median(_t, _opts),
30-
do: raise(ArgumentError, message: "tensor must be of rank 1 or 2")
31-
32-
defn median_n(t, opts) do
33-
{k0, k1} = opts[:kernel_shape]
3418

3519
idx =
36-
Nx.stack([Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1)], axis: -1)
37-
|> Nx.reshape({:auto, 2})
20+
t
21+
|> idx_tensor()
3822
|> Nx.vectorize(:elements)
3923

4024
t
41-
|> Nx.slice([idx[0], idx[1]], [k0, k1])
25+
|> Nx.slice(start_indices(t, idx), kernel_lengths(opts[:kernel_shape]))
4226
|> Nx.median()
4327
|> Nx.devectorize(keep_names: false)
4428
|> Nx.reshape(t.shape)
@@ -52,4 +36,20 @@ defmodule NxSignal.Filters do
5236
raise ArgumentError, message: "kernel shape must be of the same rank as the tensor"
5337
end
5438
end
39+
40+
deftransformp idx_tensor(t) do
41+
t
42+
|> Nx.axes()
43+
|> Enum.map(&Nx.iota(t.shape, axis: &1))
44+
|> Nx.stack(axis: -1)
45+
|> Nx.reshape({:auto, length(Nx.axes(t))})
46+
end
47+
48+
deftransformp start_indices(t, idx_tensor) do
49+
t
50+
|> Nx.axes()
51+
|> Enum.map(&idx_tensor[&1])
52+
end
53+
54+
deftransformp kernel_lengths(kernel_shape), do: Tuple.to_list(kernel_shape)
5555
end

test/nx_signal/filters_test.exs

+65-20
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,71 @@ defmodule NxSignal.FiltersTest do
3131
assert NxSignal.Filters.median(t, opts) == expected
3232
end
3333

34+
test "performs n-dim median filter" do
35+
t =
36+
Nx.tensor([
37+
[
38+
[31, 11, 17, 13, 1],
39+
[1, 3, 19, 23, 29],
40+
[19, 5, 7, 37, 2]
41+
],
42+
[
43+
[19, 5, 7, 37, 2],
44+
[1, 3, 19, 23, 29],
45+
[31, 11, 17, 13, 1]
46+
],
47+
[
48+
[1, 3, 19, 23, 29],
49+
[31, 11, 17, 13, 1],
50+
[19, 5, 7, 37, 2]
51+
]
52+
])
53+
54+
k1 = {3, 3, 1}
55+
k2 = {3, 3, 3}
56+
57+
expected1 =
58+
Nx.tensor([
59+
[
60+
[19.0, 5.0, 17.0, 23.0, 2.0],
61+
[19.0, 5.0, 17.0, 23.0, 2.0],
62+
[19.0, 5.0, 17.0, 23.0, 2.0]
63+
],
64+
[
65+
[19.0, 5.0, 17.0, 23.0, 2.0],
66+
[19.0, 5.0, 17.0, 23.0, 2.0],
67+
[19.0, 5.0, 17.0, 23.0, 2.0]
68+
],
69+
[
70+
[19.0, 5.0, 17.0, 23.0, 2.0],
71+
[19.0, 5.0, 17.0, 23.0, 2.0],
72+
[19.0, 5.0, 17.0, 23.0, 2.0]
73+
]
74+
])
75+
76+
expected2 =
77+
Nx.tensor([
78+
[
79+
[11.0, 13.0, 17.0, 17.0, 17.0],
80+
[11.0, 13.0, 17.0, 17.0, 17.0],
81+
[11.0, 13.0, 17.0, 17.0, 17.0]
82+
],
83+
[
84+
[11.0, 13.0, 17.0, 17.0, 17.0],
85+
[11.0, 13.0, 17.0, 17.0, 17.0],
86+
[11.0, 13.0, 17.0, 17.0, 17.0]
87+
],
88+
[
89+
[11.0, 13.0, 17.0, 17.0, 17.0],
90+
[11.0, 13.0, 17.0, 17.0, 17.0],
91+
[11.0, 13.0, 17.0, 17.0, 17.0]
92+
]
93+
])
94+
95+
assert NxSignal.Filters.median(t, kernel_shape: k1) == expected1
96+
assert NxSignal.Filters.median(t, kernel_shape: k2) == expected2
97+
end
98+
3499
test "raises if kernel_shape is not compatible" do
35100
t1 = Nx.iota({10})
36101
opts1 = [kernel_shape: {5, 5}]
@@ -50,25 +115,5 @@ defmodule NxSignal.FiltersTest do
50115
fn -> NxSignal.Filters.median(t2, opts2) end
51116
)
52117
end
53-
54-
test "raises if tensor rank is not 1 or 2" do
55-
t1 = Nx.tensor(1)
56-
opts1 = [kernel_shape: {1}]
57-
58-
assert_raise(
59-
ArgumentError,
60-
"tensor must be of rank 1 or 2",
61-
fn -> NxSignal.Filters.median(t1, opts1) end
62-
)
63-
64-
t2 = Nx.iota({5, 5, 5})
65-
opts2 = [kernel_shape: {3, 3, 3}]
66-
67-
assert_raise(
68-
ArgumentError,
69-
"tensor must be of rank 1 or 2",
70-
fn -> NxSignal.Filters.median(t2, opts2) end
71-
)
72-
end
73118
end
74119
end

0 commit comments

Comments
 (0)