-
Notifications
You must be signed in to change notification settings - Fork 6
/
auxiva_gauss.py
146 lines (120 loc) · 5.12 KB
/
auxiva_gauss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) 2019 Robin Scheibler
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
'''
Blind Source Separation using Independent Vector Analysis with Auxiliary Function
based on time-varying Gauss distribution.
2018 (c) Robin Scheibler, MIT License
'''
import numpy as np
from pyroomacoustics import stft, istft
from pyroomacoustics.bss import projection_back
# Catalogue of contrast functions, looked up by name.
# Each entry holds the function itself under 'f' and its derivative with
# respect to ``r`` under 'df'. Both take three arguments ``(r, c, m)``:
# the value ``r`` and two scalar parameters ``c`` and ``m``.
f_contrasts = dict(
    # f(r) = c * r, f'(r) = c (``m`` is accepted but not used)
    norm=dict(
        f=lambda r, c, m: c * r,
        df=lambda r, c, m: c,
    ),
    # f(r) = m * log(cosh(c * r)), f'(r) = c * m * tanh(c * r)
    cosh=dict(
        f=lambda r, c, m: m * np.log(np.cosh(c * r)),
        df=lambda r, c, m: c * m * np.tanh(c * r),
    ),
)
def auxiva_gauss(X, n_src=None, n_iter=20, proj_back=True, W0=None,
                 f_contrast=None, f_contrast_args=[],
                 return_filters=False, callback=None):
    '''
    Implementation of AuxIVA algorithm for BSS presented in

    N. Ono, *Stable and fast update rules for independent vector analysis based
    on auxiliary function technique*, Proc. IEEE, WASPAA, pp. 189-192, September, 2011.

    This version uses a time-varying Gauss source model: at every iteration
    each source/frame pair is weighted by ``1 / (2 * r)`` where ``r`` is the
    power of the current source estimate averaged over frequencies.

    Parameters
    ----------
    X: ndarray (nframes, nfrequencies, nchannels)
        STFT representation of the signal
    n_src: int, optional
        The number of sources or independent components. Defaults to the
        number of channels; only the determined case
        (``n_src == nchannels``) is supported.
    n_iter: int, optional
        The number of iterations (default 20)
    proj_back: bool, optional
        Scaling on first mic by back projection (default True)
    W0: ndarray (nfrequencies, nchannels, nchannels), optional
        Initial value for demixing matrix. It is copied, the caller's array
        is left untouched. Defaults to the identity for every frequency.
    f_contrast: dict of functions, optional
        Accepted for signature compatibility only. This implementation
        always uses the fixed Gauss weighting and never calls it.
    f_contrast_args: list, optional
        Accepted for signature compatibility only, never read or modified.
    return_filters: bool
        If true, the function will return the demixing matrix too
    callback: func
        A callback function called every 10 iterations (starting with the
        first one) with the current separated signals, allows to monitor
        convergence

    Returns
    -------
    Returns an (nframes, nfrequencies, nsources) array. Also returns
    the demixing matrix (nfrequencies, nchannels, nsources)
    if ``return_filters`` keyword is True.

    Raises
    ------
    ValueError
        If ``n_src`` differs from the number of channels of ``X``.
    '''
    n_frames, n_freq, n_chan = X.shape

    # default to determined case
    if n_src is None:
        n_src = n_chan

    # for now, only supports determined case; raise explicitly because an
    # ``assert`` would be stripped when running with ``python -O``
    if n_chan != n_src:
        raise ValueError(
            'auxiva_gauss only supports the determined case '
            '(n_src == n_chan), got n_src={} and n_chan={}'.format(n_src, n_chan)
        )

    # initialize the demixing matrices
    if W0 is None:
        W = np.array([np.eye(n_chan, n_src) for _ in range(n_freq)], dtype=X.dtype)
    else:
        W = W0.copy()

    # Lower bound on the source power. Without it, a frame with zero energy
    # gives a division by zero and the resulting inf/NaN contaminates every
    # demixing filter.
    eps = 1e-15

    eye = np.eye(n_src, n_src)
    Y = np.zeros((n_frames, n_freq, n_src), dtype=X.dtype)
    V = np.zeros((n_freq, n_src, n_chan, n_chan), dtype=X.dtype)

    # Compute the demixed output (in place, into Y)
    def demix(Y, X, W):
        for f in range(n_freq):
            Y[:, f, :] = np.dot(X[:, f, :], np.conj(W[f, :, :]))

    for epoch in range(n_iter):

        demix(Y, X, W)

        if callback is not None and epoch % 10 == 0:
            if proj_back:
                z = projection_back(Y, X[:, :, 0])
                callback(Y * np.conj(z[None, :, :]))
            else:
                callback(Y)

        # power of each source estimate averaged over frequencies
        # shape: (n_frames, n_src)
        r = np.mean(np.abs(Y * np.conj(Y)), axis=1)

        # weights of the time-varying Gauss model, shape (n_frames, n_src)
        G_r = 1. / np.maximum(r, eps) / 2.

        # Compute Auxiliary Variable (weighted covariance of the input)
        for f in range(n_freq):
            X_f = X[:, f, :]
            X_f_conj = np.conj(X_f)
            for s in range(n_src):
                V[f, s, :, :] = np.dot(G_r[None, :, s] * X_f.T, X_f_conj) / n_frames

        # Update now the demixing matrix. The sources are processed one after
        # the other on purpose: each update uses the filters already updated
        # for the previous sources.
        for f in range(n_freq):
            for s in range(n_src):
                WV = np.dot(np.conj(W[f, :, :].T), V[f, s, :, :])
                W[f, :, s] = np.linalg.solve(WV, eye[:, s])
                # normalize so that w^H V w = 1
                W[f, :, s] /= np.sqrt(
                    np.inner(np.conj(W[f, :, s]), np.dot(V[f, s, :, :], W[f, :, s]))
                )

    demix(Y, X, W)

    if proj_back:
        z = projection_back(Y, X[:, :, 0])
        Y *= np.conj(z[None, :, :])

    if return_filters:
        return Y, W
    else:
        return Y