-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstft.py
71 lines (57 loc) · 2.05 KB
/
stft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copied from https://github.com/wangtianrui/DCCRN/blob/4f961fcb3e431e3d2d4393b40532fe982692b45c/utils/conv_stft.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build the Conv1d kernel that implements a windowed real DFT.

    Args:
        win_len: Window length in samples; this is also the kernel width.
        win_inc: Hop size. Not used here; kept for signature compatibility.
        fft_len: DFT size. Must be at least ``win_len``.
        win_type: Window name passed to ``scipy.signal.get_window``. ``None``
            or the string ``"None"`` selects a rectangular (all-ones) window.
        invers: If True, build the (transposed) pseudo-inverse of the
            forward basis instead of the forward basis itself.

    Returns:
        Tuple ``(kernel, window)``:

        * ``kernel``: float32 tensor of shape
          ``(2 * (fft_len // 2 + 1), 1, win_len)``. The first
          ``fft_len // 2 + 1`` output channels hold the real parts of the DFT
          basis and the remaining ones hold the imaginary parts. Every
          channel is multiplied by ``window``.
        * ``window``: float32 tensor of shape ``(1, win_len, 1)``.

    Raises:
        ValueError: If ``fft_len < win_len``. The basis would then have fewer
            taps than the window, and the two could not be combined.
    """
    if fft_len < win_len:
        raise ValueError(
            f"fft_len ({fft_len}) must be >= win_len ({win_len})"
        )

    if win_type is None or win_type == "None":
        window = np.ones(win_len)
    else:
        # Square root of the named window. The same factor is applied to both
        # the forward and the inverse (invers=True) kernels below, presumably
        # so that analysis x synthesis yields the full window.
        window = get_window(win_type, win_len, fftbins=True) ** 0.5

    # rfft of the identity matrix: row k is the DFT of a unit impulse at
    # sample k, i.e. exp(-2j*pi*k*f/fft_len) for f = 0..fft_len//2.
    # Keep only the first win_len rows (one per window sample).
    fourier_basis = np.fft.rfft(np.eye(fft_len))[:win_len]
    # (win_len, 2 * n_bins) -> (2 * n_bins, win_len): one row per output channel.
    kernel = np.concatenate([fourier_basis.real, fourier_basis.imag], axis=1).T
    if invers:
        kernel = np.linalg.pinv(kernel).T
    kernel = kernel * window
    # Insert the in_channels=1 axis expected by F.conv1d weights.
    kernel = kernel[:, None, :]
    return (
        torch.from_numpy(kernel.astype(np.float32)),
        torch.from_numpy(window[None, :, None].astype(np.float32)),
    )
class ConvSTFT(nn.Module):
    """Short-time Fourier transform computed as a single strided ``conv1d``.

    Each output channel of the convolution is one windowed DFT basis
    function, either its real part or its imaginary part. The basis is built
    by ``init_kernels``. No padding is applied, so each output frame
    corresponds to one full window of input samples, advanced by
    ``win_inc``.

    Args:
        win_len: Window length in samples (kernel width).
        win_inc: Hop size in samples (convolution stride).
        fft_len: DFT size. When ``None``, the smallest power of two that is
            ``>= win_len`` is used.
        win_type: Window name forwarded to ``init_kernels``.
        feature_type: ``"complex"`` returns the raw convolution output. Any
            other value returns a ``(magnitude, phase)`` pair.
        fix: If True, the kernel is stored with ``requires_grad=False``.
    """

    def __init__(
        self,
        win_len,
        win_inc,
        fft_len=None,
        win_type="hamming",
        feature_type="real",
        fix=True,
    ):
        super().__init__()
        if fft_len is None:
            # Builtin int(): the np.int alias was deprecated in NumPy 1.20 and
            # removed in 1.24, where it raises AttributeError.
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """Transform a batch of signals.

        Args:
            inputs: A 2-D tensor, which is unsqueezed to 3-D at dim 1, or a
                3-D tensor. The kernel has a single input channel, so 3-D
                input presumably has shape ``(batch, 1, samples)``.

        Returns:
            If ``feature_type == "complex"``: the conv output of shape
            ``(batch, 2 * (dim // 2 + 1), frames)``, with the real parts in
            the first half of the channels and the imaginary parts in the
            second half.
            Otherwise: ``(mags, phase)``, each of shape
            ``(batch, dim // 2 + 1, frames)``. Phase is in radians
            (``atan2``).
        """
        if inputs.dim() == 2:
            inputs = inputs.unsqueeze(1)
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
        if self.feature_type == "complex":
            return outputs

        num_bins = self.dim // 2 + 1
        real = outputs[:, :num_bins, :]
        imag = outputs[:, num_bins:, :]
        mags = torch.sqrt(real**2 + imag**2)
        phase = torch.atan2(imag, real)
        return mags, phase