forked from Pim-Mostert/decoding-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_LDA.m
76 lines (59 loc) · 1.96 KB
/
train_LDA.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function decoder = train_LDA(cfg0, X, Y)
% [decoder] = train_LDA(cfg, X, Y)
% Trains a linear discriminant analysis decoder.
%
% X     Array of length N, where N is the number of trials, that specifies the class
%       label (0 or 1) for each trial. Trials with any other label are ignored.
% Y     Matrix of size F x N, where F is the number of features, that contains the
%       training data.
% cfg   Configuration struct that can possess the following fields:
%       .gamma = [scalar]           Shrinkage regularization parameter, with range [0 1].
%                                   No default given; if omitted, no shrinkage is
%                                   applied and a warning is issued.
%       .discardNan = 'yes' or 'no' Whether trials with NaN in either X or Y should be
%                                   removed prior to training. Default is 'no'.
%
% decoder   The trained decoder, that may be passed to an appropriate decoding function.
%           Struct with fields:
%           .W   1 x F weight vector, scaled such that W*(m1 - m0)' = 1, where m0 and
%                m1 are the means of class 0 and class 1.
%           .mY  1 x F unweighted average of the two class means, used for
%                demeaning during decoding.
%
% An error is raised if class 0 or class 1 contains no trials (after optional
% NaN removal).
%
% See also DECODE_LDA
% Created by Pim Mostert, 2016
decoder = [];

%% Pre-process cfg-struct
if ~isfield(cfg0, 'discardNan')
    cfg0.discardNan = 'no';
end
if ~isfield(cfg0, 'gamma')
    warning('train_LDA:noRegularization', ...
        sprintf('No regularization (cfg.gamma) specified!\nIf this is intended, then please specify cfg.gamma = 0'));
end

%% Pre-process data
X = X(:);   % labels as an N x 1 column
Y = Y';     % data as N x F: one row per trial
if strcmp(cfg0.discardNan, 'yes')
    iNan = isnan(X) | any(isnan(Y), 2);
    X = X(~iNan);
    Y = Y(~iNan, :);
end
numF = size(Y, 2);

% Both classes are required to form a mean difference and a pooled covariance;
% otherwise the decoder would silently consist of NaNs.
i0 = (X == 0);
i1 = (X == 1);
if ~any(i0) || ~any(i1)
    error('train_LDA:emptyClass', ...
        'Both classes (label 0 and label 1) must contain at least one trial.');
end

%% Calculate decoder
% Group means and (population-normalized) covariances
[m0, S0] = class_stats(Y(i0, :));
[m1, S1] = class_stats(Y(i1, :));

% Mean difference
d = m1 - m0;

% Pool covariance (unweighted average of the two class covariances)
S = 0.5*(S0 + S1);

% Regularize pooled covariance matrix: shrink towards a scaled identity with
% the same average variance (trace(S)/numF)
if isfield(cfg0, 'gamma')
    S = (1-cfg0.gamma)*S + cfg0.gamma*eye(numF)*trace(S)/numF;
end

% Calculate weights: solves W*S = d, i.e. W = d*inv(S)
W = d/S;

% Normalize weights so that the projected mean difference equals 1
W = W/(W*d');

%% Include overall (unweighted) mean for demeaning during decoding
decoder.mY = 0.5*(m0 + m1);

%% Return
decoder.W = W;
end

function [m, S] = class_stats(Yk)
% Mean (1 x F) and population covariance (F x F, normalized by the number of
% trials) of the trials stored in the rows of Yk (n x F, n >= 1).
% Computed explicitly instead of with cov(Yk, 1): cov treats a single-row
% input (a class with one trial) as a vector of observations and returns a
% scalar variance rather than the F x F zero matrix.
n = size(Yk, 1);
m = mean(Yk, 1);
Yc = bsxfun(@minus, Yk, m);
S = (Yc' * Yc) / n;
end