function cf = train_kernel_fda(param,X,clabel)
% Trains a kernel Fisher Discriminant Analysis (KFDA) classifier. Works
% with an arbitrary number of classes. For a linear kernel, KFDA is
% equivalent to LDA.
%
% Usage:
% cf = train_kernel_fda(param,X,clabel)
%
% Parameters:
% X - [samples x features] matrix of training samples -OR-
% [samples x samples] kernel matrix
% clabel - [samples x 1] vector of class labels
%
% param - struct with hyperparameters:
% .reg - type of regularization
% 'shrink': shrinkage regularization using (1-lambda)*N +
% lambda*nu*I, where nu = trace(N)/P and P =
% number of samples. nu assures that the trace of
% N is equal to the trace of the regularization
% term.
% 'ridge': ridge-type regularization of N + lambda*I,
% where N is the dual within-class scatter matrix
% and I is the identity matrix
% (default 'shrink')
% .lambda - regularization strength. For reg='shrink', lambda
% ranges from 0 (no regularization) to 1 (maximum
% regularization); for reg='ridge', it is a positive
% number added to the diagonal of N. (default 10^-5)
% .kernel - kernel function:
% 'linear' - linear kernel ker(x,y) = x' y
% 'rbf' - radial basis function or Gaussian kernel
% ker(x,y) = exp(-gamma * |x-y|^2);
% 'polynomial' - polynomial kernel
% ker(x,y) = (gamma * x' * y + coef0)^degree
% Alternatively, a custom kernel can be provided if a
% function called *_kernel is on the MATLAB path, where
% "*" is the name of the kernel (e.g. rbf_kernel).
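%
% For instance, a custom kernel named 'mykernel' would live in a file
% mykernel_kernel.m (the name is illustrative) and must match the
% signature this function calls, kernelfun(param, X):
%
%   function K = mykernel_kernel(param, X)
%   % param is the hyperparameter struct; X is [samples x features].
%   % Returns the [samples x samples] kernel matrix. A homogeneous
%   % quadratic kernel serves purely as a sketch here.
%   K = (X * X').^2;
%   end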
%
% If a precomputed kernel matrix is provided as X, set
% param.kernel = 'precomputed'.
%
% HYPERPARAMETERS for specific kernels:
%
% gamma - (kernel: rbf, polynomial) controls the 'width' of the
% kernel. If set to 'auto', gamma is set to 1/(nr of features)
% (default 'auto')
% coef0 - (kernel: polynomial) constant added to the polynomial
% term in the polynomial kernel. If 0, the kernel is
% homogeneous (default 1)
% degree - (kernel: polynomial) degree of the polynomial term.
% Higher degrees make overfitting more likely (default 2)
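%
% Example (a minimal usage sketch; the values shown are the documented
% defaults, and X/clabel are training data and labels as above):
%
%   param = [];
%   param.reg    = 'shrink';
%   param.lambda = 10^-5;
%   param.kernel = 'rbf';
%   param.gamma  = 'auto';
%   param.coef0  = 1;     % used by the polynomial kernel only
%   param.degree = 2;     % used by the polynomial kernel only
%   cf = train_kernel_fda(param, X, clabel);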
%
% IMPLEMENTATION DETAILS:
% The notation of Mika et al. (1999) is used below; see also the Wikipedia page:
% https://en.wikipedia.org/wiki/Kernel_Fisher_discriminant_analysis#Kernel_trick_with_LDA
%
% REFERENCE:
% Mika S, Raetsch G, Weston J, Schoelkopf B, Mueller KR (1999).
% Fisher discriminant analysis with kernels. Neural Networks for Signal
% Processing IX: 41–48.
% (c) Matthias Treder
% not currently used (since we regularize N instead):
% regularize_kernel - regularization parameter for the kernel matrix. The
% kernel matrix K is replaced by K + regularize_kernel*I, where I
% is the identity matrix (default 10e-10)
nclasses = max(clabel);
nsamples = size(X,1);
% Number of samples per class
l = arrayfun(@(c) sum(clabel == c), 1:nclasses);
% indicates whether kernel matrix has been precomputed
is_precomputed = strcmp(param.kernel,'precomputed');
%% Set kernel hyperparameter defaults
if ischar(param.gamma) && strcmp(param.gamma,'auto') && ~is_precomputed
    param.gamma = 1/size(X,2);
end
%% Compute kernel
if is_precomputed
    K = X;
else
    kernelfun = str2func([param.kernel '_kernel']);  % kernel function handle
    K = kernelfun(param, X);                         % compute kernel matrix
    % % Regularize
    % if param.regularize_kernel > 0
    %     K = K + param.regularize_kernel * eye(size(X,1));
    % end
end
%% N: "Dual" of within-class scatter matrix
N = zeros(nsamples);
% Get indices of samples for each class
ix = arrayfun( @(c) clabel==c, 1:nclasses,'Un',0);
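% Note: eye(l(c)) - 1/l(c) subtracts the scalar 1/l(c) from every entry
% of the identity, i.e. it forms the centering matrix I - (1/l_c)*ones,
% so each term below removes the class mean in feature space.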
for c=1:nclasses
    N = N + K(:,ix{c}) * (eye(l(c)) - 1/l(c)) * K(ix{c},:);
end
%% Regularization of N
lambda = param.lambda;
if strcmp(param.reg,'shrink')
    % SHRINKAGE REGULARIZATION
    % We write the regularized scatter matrix as a convex combination of
    % N and an identity matrix scaled to have the same trace as N
    N = (1-lambda) * N + lambda * eye(nsamples) * trace(N)/nsamples;
else
    % RIDGE REGULARIZATION
    % The ridge lambda must be provided directly as a positive number
    N = N + lambda * eye(nsamples);
end
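% Sanity check for the shrinkage case: with nu = trace(N)/nsamples,
% trace((1-lambda)*N + lambda*nu*I) = (1-lambda)*trace(N) + lambda*nu*nsamples
% = trace(N), so the regularization preserves the total trace.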
%% M: "Dual" of between-classes scatter matrix
% Get class-wise means
Mj = zeros(nsamples,nclasses);
for c=1:nclasses
    Mj(:,c) = mean( K(:, ix{c}), 2);
end
% Sample mean
Ms = mean(K,2);
% Calculate M
M = zeros(nsamples);
for c=1:nclasses
    M = M + l(c) * (Mj(:,c)-Ms) * (Mj(:,c)-Ms)';
end
%% Calculate A (matrix of alpha's)
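% M is a sum of nclasses rank-1 terms whose weighted deviations from the
% grand mean sum to zero, so rank(M) <= nclasses-1: there are at most
% nclasses-1 discriminant directions, extracted here as the leading
% eigenvectors of N\M.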
[A,~] = eigs( N\M, nclasses-1);
%% Set up classifier struct
cf = [];
cf.kernel = param.kernel;
cf.A = A;
cf.nclasses = nclasses;
if ~is_precomputed
    cf.kernelfun = kernelfun;
    cf.Xtrain = X;
end
% Save projected class centroids
cf.class_means = Mj'*A;
% Hyperparameters
cf.gamma = param.gamma;
cf.coef0 = param.coef0;
cf.degree = param.degree;
end
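% Test-time sketch (an assumption about usage, not the toolbox's own test
% function; pdist2 requires the Statistics and Machine Learning Toolbox):
% project the test samples via the cross-kernel with the training data,
% then assign each sample to the nearest projected class centroid. An RBF
% kernel is assumed here for illustration.
%
%   D     = pdist2(Xtest, cf.Xtrain).^2;    % squared Euclidean distances
%   Ktest = exp(-cf.gamma * D);             % RBF cross-kernel [test x train]
%   proj  = Ktest * cf.A;                   % project onto discriminant axes
%   [~, predicted] = min(pdist2(proj, cf.class_means), [], 2);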