-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathacf_est.cpp
366 lines (325 loc) · 10.6 KB
/
acf_est.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
/**
* Compile with the following in matlab:
* mex CXXFLAGS='$CXXFLAGS -std=c++1z -O3 -march=native -Wall -Wextra
* -Wpedantic' acf_est.cpp
*/
#include <algorithm>
#include <atomic>
#include <cmath>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
// Matlab mex headers
#include "matrix.h"
#include "mex.h"
// Agner Fog's vectorclass library
#include "vectorclass/vectorclass.h"
// TODO list:
// * Add a debug mode that prints diagnostics and sets SINGLE_THREAD_MODE to 1?
// * Update vectorclass version?
// * Use vectorclass instruction set detection
// * Makefile?
/** Set to 1 to spawn no worker threads and calculate everything in the calling
 * (main) thread. May be useful when debugging. Can also be set from the
 * compiler command line, e.g. -DSINGLE_THREAD_MODE=1. */
#ifndef SINGLE_THREAD_MODE
#define SINGLE_THREAD_MODE 0
#endif
/** Set to 1 to enable vectorized instructions. Forced to 0 below when no
 * supported SIMD instruction set is detected. Can also be set from the
 * compiler command line, e.g. -DVECTORIZATION=0. */
#ifndef VECTORIZATION
#define VECTORIZATION 1
#endif
// Handle restrict keyword (MSVC identifies itself with _MSC_VER)
#if defined __GNUC__ || defined __clang__
#define RESTRICT __restrict__
#elif defined _MSC_VER
#define RESTRICT __restrict
#else
#define RESTRICT
#endif
// Detect instruction set (INSTRSET is defined by vectorclass) and set vector
// types for single and double accordingly
#if INSTRSET >= 10 // AVX512VL
typedef Vec16f vec_single_t;
typedef Vec8d vec_double_t;
#elif INSTRSET >= 8 // AVX2
typedef Vec8f vec_single_t;
typedef Vec4d vec_double_t;
#elif INSTRSET >= 2 // SSE2 up to and including AVX
typedef Vec4f vec_single_t;
typedef Vec2d vec_double_t;
#else
// No supported SIMD instruction set: use scalar types and scalar code. The
// macro is undefined first because redefining it with a different value is
// ill-formed.
#undef VECTORIZATION
#define VECTORIZATION 0
typedef float vec_single_t;
typedef double vec_double_t;
#endif
/** One unit of work: a half-open range of rows (lags) of a half-open range of
 * columns. All indices default to 0, i.e. an empty work item. */
struct WorkItem {
  /** Row start index (inclusive, 0 <= nStart <= nEnd) */
  mwSize nStart = 0;
  /** Row end index (exclusive, nStart <= nEnd <= N) */
  mwSize nEnd = 0;
  /** Column start index (inclusive, 0 <= cStart <= cEnd) */
  mwSize cStart = 0;
  /** Column end index (exclusive, cStart <= cEnd <= C) */
  mwSize cEnd = 0;
};
/** Parameters to worker threads. Each worker receives its own copy; the
 * pointed-to data (input, output, queue and queue index) is shared. */
struct ThreadParams {
  /** Number of matrix rows (Size of each ACF estimation) */
  mwSize N = 0;
  /** Number of matrix columns (Number of ACFs to estimate) */
  mwSize C = 0;
  /** Input matrix [N, C], column-major, element type float or double */
  const void* x = nullptr;
  /** Output matrix [N, C]: row k of column c holds the estimate at lag k
   * (only non-negative lags are stored) */
  void* y = nullptr;
  /** Current (global) index in work queue, shared by all workers */
  std::atomic<size_t>* workQueueIdx = nullptr;
  /** Work queue */
  std::vector<WorkItem>* workQueue = nullptr;
};
// Forward declarations (parameter names match the definitions below)
void mexFunction(int nlhs, mxArray** plhs, int nrhs, const mxArray** prhs);
void checkArguments(int nlhs, mxArray** plhs, int nrhs, const mxArray** prhs);
mxArray* spawnThreads(const mxArray* vIn);
template <typename Tvec, typename Tscal>
void* calculate(const ThreadParams& p);
std::vector<WorkItem> divideWork(mwSize N, mwSize C, mwSize nThreads);
const WorkItem* nextWorkItem(const ThreadParams& p);
/**
 * Matlab mex entry function. Efficient calculation of Bartlett's (biased,
 * 1/N-normalized) estimate of the auto correlation function of each column of
 * the input, for lags 0 .. N-1.
 *
 * \param nlhs Number of left hand parameters
 * \param plhs Left hand parameters [nlhs]
 * \param nrhs Number of right hand parameters
 * \param prhs Right hand parameters [nrhs]
 *
 */
void mexFunction(int nlhs, mxArray** plhs, int nrhs, const mxArray** prhs) {
  // Raises a MATLAB error (and does not return) on invalid arguments
  checkArguments(nlhs, plhs, nrhs, prhs);
  mxArray* y = spawnThreads(prhs[0]);
  // MATLAB always reserves plhs[0], even when nlhs == 0 (the result then
  // becomes `ans`), so the output can be assigned unconditionally.
  plhs[0] = y;
}
/**
* Check that arguments in are valid
*
* \param nlhs Number of left hand parameters
* \param plhs Left hand parameters [nlhs]
* \param nrhs Number of right hand parameters
* \param prhs Right hand parameters [nrhs]
*
*/
void checkArguments(int nlhs, mxArray** plhs, int nrhs, const mxArray** prhs) {
(void)(plhs); // Unused
if (nrhs != 1)
mexErrMsgIdAndTxt("acf_est:checkArguments", "One input required");
if (!mxIsSingle(prhs[0]) && !mxIsDouble(prhs[0]))
mexErrMsgIdAndTxt("acf_est:checkArguments",
"Input matrix must be of type single or double");
if (mxIsComplex(prhs[0]))
mexErrMsgIdAndTxt("acf_est:checkArguments",
"Input matrix cannot be complex");
if (mxGetNumberOfDimensions(prhs[0]) >= 4)
mexErrMsgIdAndTxt("acf_est:checkArguments",
"Cannot handle 4-dimensional matrices or greater");
if (nlhs > 1)
mexErrMsgIdAndTxt("acf_est:checkArguments", "One or zero outputs required");
}
/**
 * Spawn worker threads that estimate the auto correlation of every series in
 * the input, and wait for them to finish.
 *
 * The estimate is taken along the first non-singleton dimension; all other
 * dimensions are treated as independent series. The output has the same size
 * and class as the input (for a matrix: N rows of lags per column).
 *
 * \param vIn Input array (real, non-sparse, single or double)
 * \return Output array
 *
 */
mxArray* spawnThreads(const mxArray* vIn) {
  const bool isSingle = mxIsSingle(vIn);
  // Create output array with the same dimensions and class as the input
  const mwSize nDims = mxGetNumberOfDimensions(vIn);
  const mwSize* dims = mxGetDimensions(vIn);
  const mxClassID classId = isSingle ? mxSINGLE_CLASS : mxDOUBLE_CLASS;
  mxArray* vOut = mxCreateNumericArray(nDims, dims, classId, mxREAL);
  // Ensure that the first non-singleton dimension is handled. Every dimension
  // before it has size 1, so the data is stored as C contiguous series of
  // length N (column-major). For 2-D input this is: N x C matrices use their
  // columns, and 1 x C row vectors are treated as one series of length C.
  mwSize d = 0;
  while (d + 1 < nDims && dims[d] == 1)
    ++d;
  const mwSize N = dims[d];
  mwSize C = 1;
  for (mwSize i = d + 1; i < nDims; ++i)
    C *= dims[i];
  // Detect parallelism. Use all threads.
  unsigned nThreads = std::thread::hardware_concurrency();
  if (nThreads == 0)
    nThreads = 1;
  // Fill work queue
  std::vector<WorkItem> workItems = divideWork(N, C, nThreads);
  // Current work queue index, shared by all workers
  std::atomic<size_t> workQueueIdx{0};
  // Set parameters passed (by copy) to the workers
  ThreadParams params{};
  params.N = N;
  params.C = C;
  params.x = mxGetData(vIn);
  params.y = mxGetData(vOut);
  params.workQueue = &workItems;
  params.workQueueIdx = &workQueueIdx;
#if SINGLE_THREAD_MODE == 0
  std::vector<std::thread> threads;
  std::string errorMsg;
  // Start all worker threads
  try {
    threads.reserve(nThreads);
    for (unsigned i = 0; i < nThreads; ++i) {
      if (isSingle)
        threads.emplace_back(calculate<vec_single_t, float>, params);
      else
        threads.emplace_back(calculate<vec_double_t, double>, params);
    }
  } catch (const std::exception& ex) {
    errorMsg = std::string("Failed to create thread: ") + ex.what();
  }
  // Wait for every started worker, also after a failed creation: the workers
  // reference workItems and workQueueIdx on this stack frame, and destroying
  // a joinable std::thread calls std::terminate.
  for (std::thread& t : threads) {
    try {
      t.join();
    } catch (const std::exception& ex) {
      if (errorMsg.empty())
        errorMsg = std::string("Failed to join thread: ") + ex.what();
    }
  }
  // The message is passed as an argument, never as the printf-style format,
  // since it may contain '%' characters.
  if (!errorMsg.empty())
    mexErrMsgIdAndTxt("acf_est:spawnThreads", "%s", errorMsg.c_str());
#else
  // A single call drains the whole work queue in the calling thread
  if (isSingle)
    calculate<vec_single_t, float>(params);
  else
    calculate<vec_double_t, double>(params);
#endif
  return vOut;
}
/**
* Caluclate Bartlett's estimate
*
* \tparam Tvec Vector type to use when calculating
* \tparam Tscal Scalar type to use when calculating
* \param p Worker thread parameters
* \return NULL
*
*/
template <typename Tvec, typename Tscal>
void* calculate(const ThreadParams& p) {
// Cast data pointers
const Tscal* RESTRICT x = static_cast<const Tscal*>(p.x); /**< Input */
Tscal* RESTRICT y = static_cast<Tscal*>(p.y); /**< Output */
// Get next work item
const WorkItem* w = nullptr;
while ((w = nextWorkItem(p)) != nullptr) {
// Iterate through columns
for (mwSize c = w->cStart; c < w->cEnd; ++c) {
// Iterate through rows
for (mwSize k = w->nStart; k < w->nEnd; ++k) {
Tscal s; /**< Current sum */
#if VECTORIZATION == 0
// Zero s since summed to
s = 0.0;
// Simplest realization
for (size_t n = 0; n < p.N - k; ++n)
s += x[n] * x[n + k];
#else
int lim = (int)p.N - (int)k - Tvec::size() + 1; /**< Iteration limit */
int n; /**< Iteration index */
// Use a sum vector and do a horizontal add when finished
Tvec sv(0);
// Vectorized loop
for (n = 0; n < lim; n += Tvec::size()) {
// Read two double vectors offset by k from memory
Tvec v1 = Tvec().load(x + c * p.N + n);
Tvec v2 = Tvec().load(x + c * p.N + n + k);
// Multiply and sum
sv = mul_add(v1, v2, sv);
}
// Finished with vector operations. Sum vector elements.
s = horizontal_add(sv);
// Non-vectorized loop for the remaining vector size - 1 elements
for (; n < (int)(p.N - k); ++n)
s += x[c * p.N + n] * x[c * p.N + n + k];
#endif
// Sum is ready. Write only half of spectra due to symmetry and for
// efficency
y[c * p.N + k] = s / p.N;
}
}
}
// Always return NULL
return nullptr;
}
/**
 * Divide work into work items of suitable size.
 *
 * Row (lag) k of a column costs N - k multiplications, so the row range of a
 * column is split into itemsPerCol contiguous ranges chosen such that each
 * range costs approximately N^2 / (2 * itemsPerCol) multiplications. The same
 * row split is used for every column.
 *
 * \param N Number of rows
 * \param C Number of columns
 * \param nThreads Number of worker threads (0 is treated as 1)
 *
 * \return List with work items: itemsPerCol items per column, ordered by
 *         column and then by row range
 */
std::vector<WorkItem> divideWork(mwSize N, mwSize C, mwSize nThreads) {
  // Guard the division below
  if (nThreads == 0)
    nThreads = 1;
  // Calculate number of work items per column as the number of
  // multiplications that takes 1 ms on a CPU with nThreads cores. Assume a
  // clock speed of 2GHz and a cost of floating point multiplication as 1 clock
  // cycle. There are N(N+1)/2 multiplications for one column (arithmetic sum).
  mwSize itemsPerCol = N * (N + 1) / (4 * nThreads * 1000000);
  if (itemsPerCol == 0)
    itemsPerCol = 1;
  // Precompute the row limits [first, second) of each item; these "heavy"
  // floating point calculations are shared by all columns. Computed in double
  // to avoid integer overflow of N * N.
  const double Nd = static_cast<double>(N);
  const double workPerItem = Nd * Nd / static_cast<double>(itemsPerCol);
  std::vector<std::pair<mwSize, mwSize>> lim(itemsPerCol);
  for (mwSize n = 0; n < itemsPerCol; ++n) {
    // a is the previous item's end b
    const mwSize a = (n == 0 ? 0 : lim[n - 1].second);
    // Work of rows [a, N) is ~(N - a)^2 / 2, so choose b such that
    // (N - b)^2 = (N - a)^2 - N^2 / itemsPerCol. Clamp to keep the square root
    // argument nonnegative.
    const double rem = Nd - static_cast<double>(a);
    const double sqrtArg = std::max(0.0, rem * rem - workPerItem);
    // The last item always extends to N so that every row is covered
    const mwSize b =
        (n + 1 == itemsPerCol
             ? N
             : N - static_cast<mwSize>(std::ceil(std::sqrt(sqrtArg))));
    lim[n] = std::make_pair(a, b);
  }
  // Make work items, one column per item
  std::vector<WorkItem> workItems(itemsPerCol * C);
  for (mwSize c = 0; c < C; ++c) {
    for (mwSize n = 0; n < itemsPerCol; ++n) {
      WorkItem& w = workItems[c * itemsPerCol + n];
      w.cStart = c;
      w.cEnd = c + 1;
      w.nStart = lim[n].first;
      w.nEnd = lim[n].second;
    }
  }
  return workItems;
}
/**
 * Get a new work item, or nothing.
 *
 * Every call atomically claims the next queue index, so each work item is
 * handed out exactly once across all worker threads.
 *
 * \param p Thread parameters
 *
 * \return Pointer to work item, or NULL when the queue is exhausted
 *
 */
const WorkItem* nextWorkItem(const ThreadParams& p) {
  // A single atomic increment claims a unique index. Relaxed ordering is
  // sufficient: the queue is fully built before the workers are started, and
  // starting a std::thread synchronizes with its thread function.
  const size_t queueIdx =
      p.workQueueIdx->fetch_add(1, std::memory_order_relaxed);
  // Check if finished
  if (queueIdx >= p.workQueue->size())
    return nullptr;
  // Get unique and valid work item
  return &(*p.workQueue)[queueIdx];
}