cwt.py
"""Module that implements the CWT using a trainable complex morlet wavelet. By Nicolas I. Tapia"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class ContinuousWaveletTransform(object):
    """CWT layer implementation in Tensorflow for GPU acceleration."""

    def __init__(self, n_scales, border_crop=0, stride=1, name="cwt"):
        """
        Args:
            n_scales: (int) Number of scales for the scalogram.
            border_crop: (int) Non-negative integer that specifies the number
                of samples to be removed at each border after computing the CWT.
                This parameter allows feeding a longer signal than the final
                desired size, so that the border effects of the CWT can be
                cropped away. Default 0.
            stride: (int) The stride of the sliding window across the input.
                Default is 1.
            name: (string) A name for the op. Default "cwt".
        """
        self.n_scales = n_scales
        self.border_crop = border_crop
        self.stride = stride
        self.name = name
        with tf.variable_scope(self.name):
            self.real_part, self.imaginary_part = self._build_wavelet_bank()

    def _build_wavelet_bank(self):
        """Needs implementation to compute the real and imaginary parts
        of the wavelet bank. Each part is expected to have shape
        [1, kernel_size, 1, n_scales]."""
        real_part = None
        imaginary_part = None
        return real_part, imaginary_part
    def __call__(self, inputs):
        """Computes the CWT with the specified wavelet bank.

        If the signal has more than one channel, the CWT is computed for
        each channel independently and the results are stacked along the
        channel axis at the end.

        Args:
            inputs: (tensor) A batch of 1D tensors of shape
                [batch_size, time_len, n_channels].

        Returns:
            Scalogram tensor with real and imaginary parts for each input
            channel. The shape of this tensor is
            [batch_size, time_len, n_scales, 2 * n_channels].
        """
        # Generate the scalogram
        border_crop = int(self.border_crop / self.stride)
        start = border_crop
        end = (-border_crop) if (border_crop > 0) else None
        with tf.variable_scope(self.name):
            # Input has expected shape of [batch_size, time_len, n_channels]
            # We first unstack the input channels
            inputs_unstacked = tf.unstack(inputs, axis=2)
            multi_channel_cwt = []
            for j, single_channel in enumerate(inputs_unstacked):
                # Reshape input [batch, time_len] -> [batch, 1, time_len, 1]
                inputs_expand = tf.expand_dims(single_channel, axis=1)
                inputs_expand = tf.expand_dims(inputs_expand, axis=3)
                with tf.name_scope('%s_%d' % (self.name, j)):
                    bank_real = self.real_part
                    bank_imag = -self.imaginary_part  # Conjugation
                    out_real = tf.nn.conv2d(
                        input=inputs_expand, filter=bank_real,
                        strides=[1, 1, self.stride, 1], padding="SAME")
                    out_imag = tf.nn.conv2d(
                        input=inputs_expand, filter=bank_imag,
                        strides=[1, 1, self.stride, 1], padding="SAME")
                    out_real_crop = out_real[:, :, start:end, :]
                    out_imag_crop = out_imag[:, :, start:end, :]
                    out_concat = tf.concat(
                        [out_real_crop, out_imag_crop], axis=1)
                    # [batch, 2, time, n_scales] -> [batch, time, n_scales, 2]
                    single_scalogram = tf.transpose(
                        out_concat, perm=[0, 2, 3, 1])
                    multi_channel_cwt.append(single_scalogram)
            # Get all in shape [batch, time_len, n_scales, 2*n_channels]
            scalograms = tf.concat(multi_channel_cwt, -1)
        return scalograms
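
# Illustrative shape trace through __call__, using assumed example values:
# for inputs of shape [4, 2000, 2] with stride=2, border_crop=200 and
# n_scales=32, each channel is reshaped to [4, 1, 2000, 1], convolved with the
# [1, kernel_size, 1, 32] bank into [4, 1, 1000, 32], cropped by
# border_crop/stride = 100 samples per side to [4, 1, 800, 32], concatenated
# with its imaginary counterpart to [4, 2, 800, 32], transposed to
# [4, 800, 32, 2], and finally the two channels are stacked to [4, 800, 32, 4].
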
class ComplexMorletCWT(ContinuousWaveletTransform):
    """CWT with the complex Morlet wavelet filter bank."""

    def __init__(
            self,
            wavelet_width,
            fs,
            lower_freq,
            upper_freq,
            n_scales,
            size_factor=1.0,
            trainable=False,
            border_crop=0,
            stride=1,
            name="cwt"):
        """Computes the complex Morlet wavelets.

        The mother wavelet is defined as:

            PSI(t) = (1 / Z) * exp(j * 2 * pi * t) * exp(-(t^2) / beta)

        where:
            beta: the wavelet width.
            t: k / fs, the time axis.
            Z: a normalization constant that depends on beta. To obtain unit
                gain at each scale, we use Z = fs * sqrt(pi * beta) / 2.

        The scaled wavelets are computed as:

            PSI_s(t) = PSI(t / scale) / scale

        Larger wavelet widths make the wavelet last longer in time, which
        improves frequency resolution at the cost of time resolution.
        Scales are computed automatically from the given frequency range and
        the desired number of scales, increasing exponentially as is commonly
        recommended.

        A Gaussian window is commonly truncated at 3 standard deviations from
        the mean. Therefore, the wavelets are truncated to the interval
            |t| <= size_factor * scale * sqrt(4.5 * wavelet_width),
        where size_factor >= 1 can optionally be set to relax this truncation.
        This might be useful when the wavelet width is trainable.

        Given this heuristic, the wavelet width can be thought of in terms of
        the number of effective cycles that the wavelet completes: if you want
        N effective cycles, set beta = N^2 / 18. For example, about 4
        effective cycles are observed when beta is approximately 0.9.

        Args:
            wavelet_width: (float or tensor) Wavelet width.
            fs: (float) Sampling frequency of the application.
            lower_freq: (float) Lower frequency of the scalogram.
            upper_freq: (float) Upper frequency of the scalogram.
            n_scales: (int) Number of scales for the scalogram.
            size_factor: (float) Factor by which the size of the kernels will
                be increased with respect to the original size. Default 1.0.
            trainable: (boolean) If True, the wavelet width is trainable.
                Defaults to False.
            border_crop: (int) Non-negative integer that specifies the number
                of samples to be removed at each border after computing the CWT.
                This parameter allows feeding a longer signal than the final
                desired size, so that the border effects of the CWT can be
                cropped away. Default 0.
            stride: (int) The stride of the sliding window across the input.
                Default is 1.
            name: (string) A name for the op. Default "cwt".
        """
        # Checking
        if lower_freq > upper_freq:
            raise ValueError("lower_freq should be lower than upper_freq")
        if lower_freq < 0:
            raise ValueError("Expected positive lower_freq.")
        self.initial_wavelet_width = wavelet_width
        self.fs = fs
        self.lower_freq = lower_freq
        self.upper_freq = upper_freq
        self.size_factor = size_factor
        self.trainable = trainable
        # Generate initial and last scale
        s_0 = 1 / self.upper_freq
        s_n = 1 / self.lower_freq
        # Generate the array of scales
        base = np.power(s_n / s_0, 1 / (n_scales - 1))
        self.scales = s_0 * np.power(base, np.arange(n_scales))
        # Generate the frequency range
        self.frequencies = 1 / self.scales
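        # Worked example with assumed values: lower_freq=0.5 Hz,
        # upper_freq=30 Hz and n_scales=5 give s_0=1/30 s, s_n=2 s and
        # base=60^(1/4)~2.78, so scales~[0.033, 0.093, 0.258, 0.719, 2.0] s
        # and frequencies~[30, 10.8, 3.9, 1.4, 0.5] Hz, i.e. log-spaced.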
        # Trainable wavelet width value
        self.wavelet_width = tf.Variable(
            initial_value=self.initial_wavelet_width,
            trainable=self.trainable,
            name='wavelet_width',
            dtype=tf.float32)
        super().__init__(n_scales, border_crop, stride, name)
    def _build_wavelet_bank(self):
        with tf.variable_scope("cmorlet_bank"):
            # Generate the wavelets.
            # We make a bigger wavelet in case the width grows during training.
            # For the size of the wavelet we use the initial width value.
            # |t| < truncation_size => |k| < truncation_size * fs
            truncation_size = (
                self.scales.max()
                * np.sqrt(4.5 * self.initial_wavelet_width) * self.fs)
            one_side = int(self.size_factor * truncation_size)
            kernel_size = 2 * one_side + 1
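            # Kernel-size sanity check with assumed numbers: fs=200 Hz,
            # max scale=2 s, initial wavelet_width=0.5 and size_factor=1 give
            # truncation_size = 2 * sqrt(2.25) * 200 = 600 samples, so
            # one_side = 600 and kernel_size = 1201 samples.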
            k_array = np.arange(kernel_size, dtype=np.float32) - one_side
            t_array = k_array / self.fs  # Time units
            # Wavelet bank shape: [1, kernel_size, 1, n_scales]
            wavelet_bank_real = []
            wavelet_bank_imag = []
            for scale in self.scales:
                norm_constant = (
                    tf.sqrt(np.pi * self.wavelet_width) * scale * self.fs / 2.0)
                scaled_t = t_array / scale
                exp_term = tf.exp(-(scaled_t ** 2) / self.wavelet_width)
                kernel_base = exp_term / norm_constant
                kernel_real = kernel_base * np.cos(2 * np.pi * scaled_t)
                kernel_imag = kernel_base * np.sin(2 * np.pi * scaled_t)
                wavelet_bank_real.append(kernel_real)
                wavelet_bank_imag.append(kernel_imag)
            # Stack wavelets (shape = [kernel_size, n_scales])
            wavelet_bank_real = tf.stack(wavelet_bank_real, axis=-1)
            wavelet_bank_imag = tf.stack(wavelet_bank_imag, axis=-1)
            # Give them proper shape for the convolutions
            # -> shape: [1, kernel_size, n_scales]
            wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=0)
            wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=0)
            # -> shape: [1, kernel_size, 1, n_scales]
            wavelet_bank_real = tf.expand_dims(wavelet_bank_real, axis=2)
            wavelet_bank_imag = tf.expand_dims(wavelet_bank_imag, axis=2)
        return wavelet_bank_real, wavelet_bank_imag
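

# Minimal usage sketch under assumed settings (TensorFlow 1.x graph mode,
# fs=200 Hz, a 0.5-30 Hz scalogram with 32 scales, and random data standing
# in for a real signal batch). The parameter values are illustrative only.
if __name__ == "__main__":
    fs = 200.0
    batch_size, time_len, n_channels = 4, 2000, 2
    border_crop = int(fs)  # Crop 1 second at each border of the scalogram
    cwt_layer = ComplexMorletCWT(
        wavelet_width=0.5, fs=fs, lower_freq=0.5, upper_freq=30.0,
        n_scales=32, border_crop=border_crop, stride=2, trainable=False)
    signals = tf.placeholder(tf.float32, shape=[None, time_len, n_channels])
    # Output shape: [batch, (time_len - 2 * border_crop) // stride, 32, 2 * n_channels]
    scalograms = cwt_layer(signals)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        dummy = np.random.randn(
            batch_size, time_len, n_channels).astype(np.float32)
        result = sess.run(scalograms, feed_dict={signals: dummy})
        print(result.shape)  # Expected: (4, 800, 32, 4)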