Skip to content

Commit d86769b

Browse files
authored
refactor: make window size mandatory argument (#28)
* refactor: make window size mandatory argument * chore: remove stray space * chore: simplify defp
1 parent e42d747 commit d86769b

File tree

2 files changed

+81
-67
lines changed

2 files changed

+81
-67
lines changed

lib/nx_signal.ex

+5-5
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ defmodule NxSignal do
4343
4444
## Examples
4545
46-
iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(n: 2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
46+
iex> {z, t, f} = NxSignal.stft(Nx.iota({4}), NxSignal.Windows.rectangular(2), overlap_length: 1, fft_length: 2, sampling_rate: 400)
4747
iex> z
4848
#Nx.Tensor<
4949
c64[frames: 3][frequencies: 2]
@@ -464,7 +464,7 @@ defmodule NxSignal do
464464
465465
iex> fft_length = 16
466466
iex> sampling_rate = 8.0e3
467-
iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(n: 4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
467+
iex> {z, _, _} = NxSignal.stft(Nx.iota({10}), NxSignal.Windows.hann(4), overlap_length: 2, fft_length: fft_length, sampling_rate: sampling_rate, window_padding: :reflect)
468468
iex> Nx.axis_size(z, :frequencies)
469469
16
470470
iex> Nx.axis_size(z, :frames)
@@ -543,7 +543,7 @@ defmodule NxSignal do
543543
of the signal end up being distorted.
544544
545545
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
546-
iex> w = NxSignal.Windows.hann(n: 4)
546+
iex> w = NxSignal.Windows.hann(4)
547547
iex> opts = [sampling_rate: 1, fft_length: 4]
548548
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
549549
iex> result = NxSignal.istft(z, w, opts)
@@ -557,7 +557,7 @@ defmodule NxSignal do
557557
For perfect reconstruction, you want to use the same scaling as the STFT:
558558
559559
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20])
560-
iex> w = NxSignal.Windows.hann(n: 4)
560+
iex> w = NxSignal.Windows.hann(4)
561561
iex> opts = [scaling: :spectrum, sampling_rate: 1, fft_length: 4]
562562
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
563563
iex> result = NxSignal.istft(z, w, opts)
@@ -568,7 +568,7 @@ defmodule NxSignal do
568568
>
569569
570570
iex> t = Nx.tensor([10, 10, 1, 0, 10, 10, 2, 20], type: :f32)
571-
iex> w = NxSignal.Windows.hann(n: 4)
571+
iex> w = NxSignal.Windows.hann(4)
572572
iex> opts = [scaling: :psd, sampling_rate: 1, fft_length: 4]
573573
iex> {z, _time, _freqs} = NxSignal.stft(t, w, opts)
574574
iex> result = NxSignal.istft(z, w, opts)

lib/nx_signal/windows.ex

+76-62
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,25 @@ defmodule NxSignal.Windows do
1313
1414
## Options
1515
16-
* `:n` - the window length
1716
* `:type` - the output type. Defaults to `s64`
1817
1918
## Examples
2019
21-
iex> NxSignal.Windows.rectangular(n: 5)
20+
iex> NxSignal.Windows.rectangular(5)
2221
#Nx.Tensor<
2322
s64[5]
2423
[1, 1, 1, 1, 1]
2524
>
2625
27-
iex> NxSignal.Windows.rectangular(n: 5, type: :f32)
26+
iex> NxSignal.Windows.rectangular(5, type: :f32)
2827
#Nx.Tensor<
2928
f32[5]
3029
[1.0, 1.0, 1.0, 1.0, 1.0]
3130
>
3231
"""
3332
@doc type: :windowing
34-
defn rectangular(opts \\ []) do
35-
opts = keyword!(opts, [:n, type: :s64])
36-
{n, opts} = pop_window_size(opts)
33+
deftransform rectangular(n, opts \\ []) when is_integer(n) do
34+
opts = Keyword.validate!(opts, type: :s64)
3735
Nx.broadcast(Nx.tensor(1, type: opts[:type]), {n})
3836
end
3937

@@ -44,22 +42,25 @@ defmodule NxSignal.Windows do
4442
4543
## Options
4644
47-
* `:n` - The window length. Mandatory option.
4845
* `:type` - the output type for the window. Defaults to `{:f, 32}`
4946
* `:name` - the axis name. Defaults to `nil`
5047
5148
## Examples
5249
53-
iex> NxSignal.Windows.bartlett(n: 3)
50+
iex> NxSignal.Windows.bartlett(3)
5451
#Nx.Tensor<
5552
f32[3]
5653
[0.0, 0.6666666865348816, 0.6666666269302368]
5754
>
5855
"""
5956
@doc type: :windowing
60-
defn bartlett(opts \\ []) do
61-
opts = keyword!(opts, [:n, :name, type: {:f, 32}])
62-
{n, opts} = pop_window_size(opts)
57+
deftransform bartlett(n, opts \\ []) when is_integer(n) do
58+
opts = Keyword.validate!(opts, type: {:f, 32})
59+
bartlett_n(Keyword.put(opts, :n, n))
60+
end
61+
62+
defnp bartlett_n(opts) do
63+
n = opts[:n]
6364
name = opts[:name]
6465
type = opts[:type]
6566

@@ -87,16 +88,20 @@ defmodule NxSignal.Windows do
8788
8889
## Examples
8990
90-
iex> NxSignal.Windows.triangular(n: 3)
91+
iex> NxSignal.Windows.triangular(3)
9192
#Nx.Tensor<
9293
f32[3]
9394
[0.5, 1.0, 0.5]
9495
>
9596
"""
9697
@doc type: :windowing
97-
defn triangular(opts \\ []) do
98-
opts = keyword!(opts, [:n, :name, type: {:f, 32}])
99-
{n, opts} = pop_window_size(opts)
98+
deftransform triangular(n, opts \\ []) when is_integer(n) do
99+
opts = Keyword.validate!(opts, [:name, type: {:f, 32}])
100+
triangular_n(Keyword.put(opts, :n, n))
101+
end
102+
103+
defnp triangular_n(opts) do
104+
n = opts[:n]
100105
name = opts[:name]
101106
type = opts[:type]
102107

@@ -126,48 +131,52 @@ defmodule NxSignal.Windows do
126131
127132
## Options
128133
129-
* `:n` - The window length. Mandatory option.
130134
* `:is_periodic` - If `true`, produces a periodic window,
131135
otherwise produces a symmetric window. Defaults to `true`
132136
* `:type` - the output type for the window. Defaults to `{:f, 32}`
133137
* `:name` - the axis name. Defaults to `nil`
134138
135139
## Examples
136140
137-
iex> NxSignal.Windows.blackman(n: 5, is_periodic: false)
141+
iex> NxSignal.Windows.blackman(5, is_periodic: false)
138142
#Nx.Tensor<
139143
f32[5]
140144
[-1.4901161193847656e-8, 0.3400000333786011, 0.9999999403953552, 0.3400000333786011, -1.4901161193847656e-8]
141145
>
142146
143-
iex> NxSignal.Windows.blackman(n: 5, is_periodic: true)
147+
iex> NxSignal.Windows.blackman(5, is_periodic: true)
144148
#Nx.Tensor<
145149
f32[5]
146150
[-1.4901161193847656e-8, 0.20077012479305267, 0.8492299318313599, 0.8492299318313599, 0.20077012479305267]
147151
>
148152
149-
iex> NxSignal.Windows.blackman(n: 6, is_periodic: true, type: {:f, 32})
153+
iex> NxSignal.Windows.blackman(6, is_periodic: true, type: {:f, 32})
150154
#Nx.Tensor<
151155
f32[6]
152156
[-1.4901161193847656e-8, 0.12999999523162842, 0.6299999952316284, 0.9999999403953552, 0.6299999952316284, 0.12999999523162842]
153157
>
154158
"""
155159
@doc type: :windowing
156-
defn blackman(opts \\ []) do
157-
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
158-
{l, opts} = pop_window_size(opts)
160+
deftransform blackman(n, opts \\ []) when is_integer(n) do
161+
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
162+
blackman_n(Keyword.put(opts, :n, n))
163+
end
164+
165+
defnp blackman_n(opts) do
166+
n = opts[:n]
159167
name = opts[:name]
160168
type = opts[:type]
161169
is_periodic = opts[:is_periodic]
162170

163171
l =
164172
if is_periodic do
165-
l + 1
173+
n + 1
166174
else
167-
l
175+
n
168176
end
169177

170-
m = div_ceil(l, 2)
178+
m =
179+
integer_div_ceil(l, 2)
171180

172181
n = Nx.iota({m}, names: [name], type: type)
173182

@@ -194,38 +203,41 @@ defmodule NxSignal.Windows do
194203
195204
## Options
196205
197-
* `:n` - The window length. Mandatory option.
198206
* `:is_periodic` - If `true`, produces a periodic window,
199207
otherwise produces a symmetric window. Defaults to `true`
200208
* `:type` - the output type for the window. Defaults to `{:f, 32}`
201209
* `:name` - the axis name. Defaults to `nil`
202210
203211
## Examples
204212
205-
iex> NxSignal.Windows.hamming(n: 5, is_periodic: true)
213+
iex> NxSignal.Windows.hamming(5, is_periodic: true)
206214
#Nx.Tensor<
207215
f32[5]
208216
[0.08000001311302185, 0.39785221219062805, 0.9121478796005249, 0.9121478199958801, 0.3978521227836609]
209217
>
210-
iex> NxSignal.Windows.hamming(n: 5, is_periodic: false)
218+
iex> NxSignal.Windows.hamming(5, is_periodic: false)
211219
#Nx.Tensor<
212220
f32[5]
213221
[0.08000001311302185, 0.5400000214576721, 1.0, 0.5400000214576721, 0.08000001311302185]
214222
>
215223
"""
216224
@doc type: :windowing
217-
defn hamming(opts \\ []) do
218-
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
219-
{l, opts} = pop_window_size(opts)
225+
deftransform hamming(n, opts \\ []) when is_integer(n) do
226+
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
227+
hamming_n(Keyword.put(opts, :n, n))
228+
end
229+
230+
defnp hamming_n(opts) do
231+
n = opts[:n]
220232
name = opts[:name]
221233
type = opts[:type]
222234
is_periodic = opts[:is_periodic]
223235

224236
l =
225237
if is_periodic do
226-
l + 1
238+
n + 1
227239
else
228-
l
240+
n
229241
end
230242

231243
n = Nx.iota({l}, names: [name], type: type)
@@ -244,38 +256,41 @@ defmodule NxSignal.Windows do
244256
245257
## Options
246258
247-
* `:n` - The window length. Mandatory option.
248259
* `:is_periodic` - If `true`, produces a periodic window,
249260
otherwise produces a symmetric window. Defaults to `true`
250261
* `:type` - the output type for the window. Defaults to `{:f, 32}`
251262
* `:name` - the axis name. Defaults to `nil`
252263
253264
## Examples
254265
255-
iex> NxSignal.Windows.hann(n: 5, is_periodic: false)
266+
iex> NxSignal.Windows.hann(5, is_periodic: false)
256267
#Nx.Tensor<
257268
f32[5]
258269
[0.0, 0.5, 1.0, 0.5, 0.0]
259270
>
260-
iex> NxSignal.Windows.hann(n: 5, is_periodic: true)
271+
iex> NxSignal.Windows.hann(5, is_periodic: true)
261272
#Nx.Tensor<
262273
f32[5]
263274
[0.0, 0.34549152851104736, 0.9045085310935974, 0.9045084714889526, 0.3454914391040802]
264275
>
265276
"""
266277
@doc type: :windowing
267-
defn hann(opts \\ []) do
268-
opts = keyword!(opts, [:n, :name, is_periodic: true, type: {:f, 32}])
269-
{l, opts} = pop_window_size(opts)
278+
deftransform hann(n, opts \\ []) when is_integer(n) do
279+
opts = Keyword.validate!(opts, [:name, is_periodic: true, type: {:f, 32}])
280+
hann_n(Keyword.put(opts, :n, n))
281+
end
282+
283+
defnp hann_n(opts) do
284+
n = opts[:n]
270285
name = opts[:name]
271286
type = opts[:type]
272287
is_periodic = opts[:is_periodic]
273288

274289
l =
275290
if is_periodic do
276-
l + 1
291+
n + 1
277292
else
278-
l
293+
n
279294
end
280295

281296
n = Nx.iota({l}, names: [name], type: type)
@@ -296,7 +311,6 @@ defmodule NxSignal.Windows do
296311
297312
## Options
298313
299-
* `:n` - The window length. Mandatory option.
300314
* `:is_periodic` - If `true`, produces a periodic window,
301315
otherwise produces a symmetric window. Defaults to `true`
302316
* `:type` - the output type for the window. Defaults to `{:f, 32}`
@@ -305,46 +319,50 @@ defmodule NxSignal.Windows do
305319
* `:axis_name` - the axis name. Defaults to `nil`
306320
307321
## Examples
308-
iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: true)
322+
iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: true)
309323
#Nx.Tensor<
310324
f32[4]
311325
[5.2776191296288744e-5, 0.21566666662693024, 1.0, 0.21566666662693024]
312326
>
313327
314-
iex> NxSignal.Windows.kaiser(n: 5, beta: 12.0, is_periodic: true)
328+
iex> NxSignal.Windows.kaiser(5, beta: 12.0, is_periodic: true)
315329
#Nx.Tensor<
316330
f32[5]
317331
[5.2776191296288744e-5, 0.10171464085578918, 0.7929369807243347, 0.7929369807243347, 0.10171464085578918]
318332
>
319333
320-
iex> NxSignal.Windows.kaiser(n: 4, beta: 12.0, is_periodic: false)
334+
iex> NxSignal.Windows.kaiser(4, beta: 12.0, is_periodic: false)
321335
#Nx.Tensor<
322336
f32[4]
323337
[5.2776191296288744e-5, 0.5188394784927368, 0.5188390612602234, 5.2776191296288744e-5]
324338
>
325339
"""
326340
@doc type: :windowing
327-
defn kaiser(opts \\ []) do
341+
deftransform kaiser(n, opts \\ []) when is_integer(n) do
328342
opts =
329-
keyword!(opts, [:n, :axis_name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}])
343+
Keyword.validate!(opts, [:name, eps: 1.0e-7, beta: 12.0, is_periodic: true, type: {:f, 32}])
344+
345+
kaiser_n(Keyword.put(opts, :n, n))
346+
end
330347

331-
{l, opts} = pop_window_size(opts)
332-
name = opts[:axis_name]
348+
defnp kaiser_n(opts) do
349+
n = opts[:n]
350+
name = opts[:name]
333351
type = opts[:type]
334352
beta = opts[:beta]
335353
eps = opts[:eps]
336354
is_periodic = opts[:is_periodic]
337355

338-
window_length = if is_periodic, do: l + 1, else: l
356+
window_length = if is_periodic, do: n + 1, else: n
339357

340-
ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type) |> Nx.rename([name])
358+
ratio = Nx.linspace(-1, 1, n: window_length, endpoint: true, type: type, name: name)
341359
sqrt_arg = Nx.max(1 - ratio ** 2, eps)
342360
r = beta * Nx.sqrt(sqrt_arg)
343361

344362
window = kaiser_bessel_i0(r) / kaiser_bessel_i0(beta)
345363

346364
if is_periodic do
347-
Nx.slice(window, [0], [l])
365+
Nx.slice(window, [0], [n])
348366
else
349367
window
350368
end
@@ -367,17 +385,13 @@ defmodule NxSignal.Windows do
367385
Nx.select(abs_x < 3.75, small_x_result, large_x_result)
368386
end
369387

370-
deftransformp pop_window_size(opts) do
371-
{n, opts} = Keyword.pop(opts, :n)
388+
deftransformp integer_div_ceil(num, den) when is_integer(num) and is_integer(den) do
389+
rem = rem(num, den)
372390

373-
if !n do
374-
raise "missing :n option"
391+
if rem == 0 do
392+
div(num, den)
393+
else
394+
div(num, den) + 1
375395
end
376-
377-
{n, opts}
378-
end
379-
380-
deftransformp div_ceil(num, den) do
381-
ceil(num / den)
382396
end
383397
end

0 commit comments

Comments
 (0)