19
19
#include < nda/nda.hpp>
20
20
#include < nda/blas.hpp>
21
21
22
-
23
22
namespace cppdlr {
24
- using dcomplex = std::complex<double >;
23
+
24
+ /* *
25
+ * Calculate the squared norm of a vector
26
+ *
27
+ * @param v The input vector
28
+ * @return x The squared norm of the vector
29
+ */
30
+ double normsq (nda::MemoryVector auto const &v) { return nda::real (nda::blas::dotc (v, v)); }
25
31
26
32
/* *
27
33
* Class constructor for barycheb: barycentric Lagrange interpolation at
@@ -116,10 +122,10 @@ namespace cppdlr {
116
122
// Compute norms of rows of input matrix, and rescale eps tolerance
117
123
auto norms = nda::vector<double >(m);
118
124
double epssq = eps * eps;
119
- for (int j = 0 ; j < m; ++j) { norms (j) = nda::real ( nda::blas::dotc ( aa (j, _), aa (j, _) )); }
125
+ for (int j = 0 ; j < m; ++j) { norms (j) = normsq ( aa (j, _)); }
120
126
121
127
// Begin pivoted double Gram-Schmidt procedure
122
- int jpiv = 0 , jj = 0 ;
128
+ int jpiv = 0 ;
123
129
double nrm = 0 ;
124
130
auto piv = nda::arange (m);
125
131
auto tmp = nda::vector<S>(n);
@@ -137,38 +143,29 @@ namespace cppdlr {
137
143
}
138
144
139
145
// Swap current row with chosen pivot row
140
- tmp = aa (j, _);
141
- aa (j, _) = aa (jpiv, _);
142
- aa (jpiv, _) = tmp;
143
-
144
- nrm = norms (j);
145
- norms (j) = norms (jpiv);
146
- norms (jpiv) = nrm;
147
-
148
- jj = piv (j);
149
- piv (j) = piv (jpiv);
150
- piv (jpiv) = jj;
146
+ deep_swap (aa (j, _), aa (jpiv, _));
147
+ std::swap (norms (j), norms (jpiv));
148
+ std::swap (piv (j), piv (jpiv));
151
149
152
150
// Orthogonalize current rows (now the chosen pivot row) against all
153
151
// previously chosen rows
154
152
for (int k = 0 ; k < j; ++k) { aa (j, _) = aa (j, _) - aa (k, _) * nda::blas::dotc (aa (k, _), aa (j, _)); }
155
153
156
154
// Get norm of current row
157
- nrm = nda::real (nda::blas::dotc (aa (j, _), aa (j, _)));
158
- // nrm = nda::norm(aa(j, _));
155
+ nrm = normsq (aa (j, _));
159
156
160
157
// Terminate if sufficiently small, and return previously selected rows
161
158
// (not including current row)
162
159
if (nrm <= epssq) { return {aa (nda::range (0 , j), _), norms (nda::range (0 , j)), piv (nda::range (0 , j))}; };
163
160
164
161
// Normalize current row
165
- aa (j, _) = aa (j, _) * ( 1 / sqrt (nrm) );
162
+ aa (j, _) /= sqrt (nrm);
166
163
167
164
// Orthogonalize remaining rows against current row
168
165
for (int k = j + 1 ; k < m; ++k) {
169
166
if (norms (k) <= epssq) { continue ; } // Can skip rows with norm less than tolerance
170
167
aa (k, _) = aa (k, _) - aa (j, _) * nda::blas::dotc (aa (j, _), aa (k, _));
171
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
168
+ norms (k) = normsq ( aa (k, _));
172
169
}
173
170
}
174
171
@@ -211,22 +208,21 @@ namespace cppdlr {
211
208
if (m % 2 != 0 ) { throw std::runtime_error (" Input matrix must have even number of rows." ); }
212
209
213
210
// Copy input data, re-ordering rows to make symmetric rows adjacent.
214
- auto aa = typename T::regular_type (m, n);
211
+ auto aa = typename T::regular_type (m, n);
215
212
aa (nda::range (0 , m, 2 ), _) = a (nda::range (0 , m / 2 ), _);
216
213
aa (nda::range (1 , m, 2 ), _) = a (nda::range (m - 1 , m / 2 - 1 , -1 ), _);
217
214
218
215
// Compute norms of rows of input matrix, and rescale eps tolerance
219
216
auto norms = nda::vector<double >(m);
220
217
double epssq = eps * eps;
221
- for (int j = 0 ; j < m; ++j) { norms (j) = nda::real ( nda::blas::dotc ( aa (j, _), aa (j, _) )); }
218
+ for (int j = 0 ; j < m; ++j) { norms (j) = normsq ( aa (j, _)); }
222
219
223
220
// Begin pivoted double Gram-Schmidt procedure
224
- int jpiv = 0 , jj = 0 ;
225
- double nrm = 0 ;
226
- auto piv = nda::arange (0 , m);
221
+ int jpiv = 0 ;
222
+ double nrm = 0 ;
223
+ auto piv = nda::arange (0 , m);
227
224
piv (nda::range (0 , m, 2 )) = nda::arange (0 , m / 2 ); // Re-order pivots to match re-ordered input matrix
228
225
piv (nda::range (1 , m, 2 )) = nda::arange (m - 1 , m / 2 - 1 , -1 );
229
- auto tmp = nda::vector<S>(n);
230
226
231
227
if (maxrnk % 2 != 0 ) { // If n < m and n is odd, decrease maxrnk to maintain symmetry
232
228
maxrnk -= 1 ;
@@ -245,61 +241,46 @@ namespace cppdlr {
245
241
}
246
242
247
243
// Swap current row pair with chosen pivot row pair
248
- tmp = aa (j, _);
249
- aa (j, _) = aa (jpiv, _);
250
- aa (jpiv, _) = tmp;
251
- tmp = aa (j + 1 , _);
252
- aa (j + 1 , _) = aa (jpiv + 1 , _);
253
- aa (jpiv + 1 , _) = tmp;
254
-
255
- nrm = norms (j);
256
- norms (j) = norms (jpiv);
257
- norms (jpiv) = nrm;
258
- nrm = norms (j + 1 );
259
- norms (j + 1 ) = norms (jpiv + 1 );
260
- norms (jpiv + 1 ) = nrm;
261
-
262
- jj = piv (j);
263
- piv (j) = piv (jpiv);
264
- piv (jpiv) = jj;
265
- jj = piv (j + 1 );
266
- piv (j + 1 ) = piv (jpiv + 1 );
267
- piv (jpiv + 1 ) = jj;
244
+ deep_swap (aa (j, _), aa (jpiv, _));
245
+ deep_swap (aa (j + 1 , _), aa (jpiv + 1 , _));
246
+ std::swap (norms (j), norms (jpiv));
247
+ std::swap (norms (j + 1 ), norms (jpiv + 1 ));
248
+ std::swap (piv (j), piv (jpiv));
249
+ std::swap (piv (j + 1 ), piv (jpiv + 1 ));
268
250
269
251
// Orthogonalize current row (now the first chosen pivot row) against all
270
252
// previously chosen rows
271
253
for (int k = 0 ; k < j; ++k) { aa (j, _) = aa (j, _) - aa (k, _) * nda::blas::dotc (aa (k, _), aa (j, _)); }
272
254
273
255
// Get norm of current row
274
- nrm = nda::real ( nda::blas::dotc ( aa (j, _), aa (j, _) ));
256
+ nrm = normsq ( aa (j, _));
275
257
276
258
// Terminate if sufficiently small, and return previously selected rows
277
259
// (not including current row)
278
260
if (nrm <= epssq) { return {aa (nda::range (0 , j), _), norms (nda::range (0 , j)), piv (nda::range (0 , j))}; };
279
261
280
262
// Normalize current row
281
- aa (j, _) = aa (j, _) * ( 1 / sqrt (nrm) );
263
+ aa (j, _) /= sqrt (nrm);
282
264
283
265
// Orthogonalize remaining rows against current row
284
266
for (int k = j + 1 ; k < m; ++k) {
285
267
if (norms (k) <= epssq) { continue ; } // Can skip rows with norm less than tolerance
286
268
aa (k, _) = aa (k, _) - aa (j, _) * nda::blas::dotc (aa (j, _), aa (k, _));
287
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
269
+ norms (k) = normsq ( aa (k, _));
288
270
}
289
271
290
272
// Orthogonalize current row (now the second chosen pivot row) against all
291
273
// previously chosen rows
292
274
for (int k = 0 ; k < j + 1 ; ++k) { aa (j + 1 , _) = aa (j + 1 , _) - aa (k, _) * nda::blas::dotc (aa (k, _), aa (j + 1 , _)); }
293
275
294
276
// Normalize current row
295
- nrm = nda::real (nda::blas::dotc (aa (j + 1 , _), aa (j + 1 , _)));
296
- aa (j + 1 , _) = aa (j + 1 , _) * (1 / sqrt (nrm));
277
+ aa (j + 1 , _) /= sqrt (normsq (aa (j + 1 , _)));
297
278
298
279
// Orthogonalize remaining rows against current row
299
280
for (int k = j + 2 ; k < m; ++k) {
300
281
if (norms (k) <= epssq) { continue ; } // Can skip rows with norm less than tolerance
301
282
aa (k, _) = aa (k, _) - aa (j + 1 , _) * nda::blas::dotc (aa (j + 1 , _), aa (k, _));
302
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
283
+ norms (k) = normsq ( aa (k, _));
303
284
}
304
285
}
305
286
@@ -352,18 +333,18 @@ namespace cppdlr {
352
333
aa (nda::range (0 , m, 2 ), _) = a (nda::range (0 , m / 2 ), _);
353
334
aa (nda::range (1 , m, 2 ), _) = a (nda::range (m - 1 , m / 2 - 1 , -1 ), _);
354
335
} else {
355
- aa (0 , _) = a ((m - 1 ) / 2 , _);
336
+ aa (0 , _) = a ((m - 1 ) / 2 , _);
356
337
aa (nda::range (1 , m, 2 ), _) = a (nda::range (0 , (m - 1 ) / 2 ), _);
357
338
aa (nda::range (2 , m, 2 ), _) = a (nda::range (m - 1 , (m - 1 ) / 2 , -1 ), _);
358
339
// aa(m - 1, _) = a((m - 1) / 2, _);
359
340
}
360
341
361
342
// Compute norms of rows of input matrix
362
343
auto norms = nda::vector<double >(m);
363
- for (int j = 0 ; j < m; ++j) { norms (j) = nda::real ( nda::blas::dotc ( aa (j, _), aa (j, _) )); }
344
+ for (int j = 0 ; j < m; ++j) { norms (j) = normsq ( aa (j, _)); }
364
345
365
346
// Begin pivoted double Gram-Schmidt procedure
366
- int jpiv = 0 , jj = 0 ;
347
+ int jpiv = 0 ;
367
348
double nrm = 0 ;
368
349
auto piv = nda::arange (0 , m);
369
350
if (m % 2 == 0 ) {
@@ -375,23 +356,17 @@ namespace cppdlr {
375
356
piv (nda::range (2 , m, 2 )) = nda::arange (m - 1 , (m - 1 ) / 2 , -1 );
376
357
// piv(m - 1) = (m - 1) / 2;
377
358
}
378
- auto tmp = nda::vector<S>(n);
379
359
380
360
// If m odd, first choose middle row (now last row) as first pivot
381
361
382
362
if (m % 2 == 1 ) {
383
- // int j = 0; // Index of current row
384
- // jpiv = 0; // Index of pivot row
385
-
386
363
// Normalize
387
- nrm = nda::real (nda::blas::dotc (aa (0 , _), aa (0 , _)));
388
- aa (0 , _) = aa (0 , _) * (1 / sqrt (nrm));
389
- // aa(0, _) /= sqrt(nda::real(nda::blas::dotc(aa(0, _), aa(0, _))));
364
+ aa (0 , _) /= sqrt (normsq (aa (0 , _)));
390
365
391
366
// Orthogonalize remaining rows against current row
392
367
for (int k = 1 ; k < m; ++k) {
393
368
aa (k, _) = aa (k, _) - aa (0 , _) * nda::blas::dotc (aa (0 , _), aa (k, _));
394
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
369
+ norms (k) = normsq ( aa (k, _));
395
370
}
396
371
}
397
372
@@ -410,53 +385,37 @@ namespace cppdlr {
410
385
}
411
386
412
387
// Swap current row pair with chosen pivot row pair
413
- tmp = aa (j, _);
414
- aa (j, _) = aa (jpiv, _);
415
- aa (jpiv, _) = tmp;
416
- tmp = aa (j + 1 , _);
417
- aa (j + 1 , _) = aa (jpiv + 1 , _);
418
- aa (jpiv + 1 , _) = tmp;
419
-
420
- nrm = norms (j);
421
- norms (j) = norms (jpiv);
422
- norms (jpiv) = nrm;
423
- nrm = norms (j + 1 );
424
- norms (j + 1 ) = norms (jpiv + 1 );
425
- norms (jpiv + 1 ) = nrm;
426
-
427
- jj = piv (j);
428
- piv (j) = piv (jpiv);
429
- piv (jpiv) = jj;
430
- jj = piv (j + 1 );
431
- piv (j + 1 ) = piv (jpiv + 1 );
432
- piv (jpiv + 1 ) = jj;
388
+ deep_swap (aa (j, _), aa (jpiv, _));
389
+ deep_swap (aa (j + 1 , _), aa (jpiv + 1 , _));
390
+ std::swap (norms (j), norms (jpiv));
391
+ std::swap (norms (j + 1 ), norms (jpiv + 1 ));
392
+ std::swap (piv (j), piv (jpiv));
393
+ std::swap (piv (j + 1 ), piv (jpiv + 1 ));
433
394
434
395
// Orthogonalize current row (now the first chosen pivot row) against all
435
396
// previously chosen rows
436
397
for (int k = 0 ; k < j; ++k) { aa (j, _) = aa (j, _) - aa (k, _) * nda::blas::dotc (aa (k, _), aa (j, _)); }
437
398
438
399
// Normalize current row
439
- nrm = nda::real (nda::blas::dotc (aa (j, _), aa (j, _)));
440
- aa (j, _) = aa (j, _) * (1 / sqrt (nrm));
400
+ aa (j, _) /= sqrt (normsq (aa (j, _)));
441
401
442
402
// Orthogonalize remaining rows against current row
443
403
for (int k = j + 1 ; k < m; ++k) {
444
404
aa (k, _) = aa (k, _) - aa (j, _) * nda::blas::dotc (aa (j, _), aa (k, _));
445
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
405
+ norms (k) = normsq ( aa (k, _));
446
406
}
447
407
448
408
// Orthogonalize current row (now the second chosen pivot row) against all
449
409
// previously chosen rows
450
410
for (int k = 0 ; k < j + 1 ; ++k) { aa (j + 1 , _) = aa (j + 1 , _) - aa (k, _) * nda::blas::dotc (aa (k, _), aa (j + 1 , _)); }
451
411
452
412
// Normalize current row
453
- nrm = nda::real (nda::blas::dotc (aa (j + 1 , _), aa (j + 1 , _)));
454
- aa (j + 1 , _) = aa (j + 1 , _) * (1 / sqrt (nrm));
413
+ aa (j + 1 , _) /= sqrt (normsq (aa (j + 1 , _)));
455
414
456
415
// Orthogonalize remaining rows against current row
457
416
for (int k = j + 2 ; k < m; ++k) {
458
417
aa (k, _) = aa (k, _) - aa (j + 1 , _) * nda::blas::dotc (aa (j + 1 , _), aa (k, _));
459
- norms (k) = nda::real ( nda::blas::dotc ( aa (k, _), aa (k, _) ));
418
+ norms (k) = normsq ( aa (k, _));
460
419
}
461
420
}
462
421
@@ -551,7 +510,7 @@ namespace cppdlr {
551
510
* @return Contraction of the inner dimensions of \p a and \p b
552
511
*/
553
512
template <nda::MemoryArray Ta, nda::MemoryArray Tb, nda::Scalar Sa = nda::get_value_t <Ta>, nda::Scalar Sb = nda::get_value_t <Tb>,
554
- nda::Scalar S = typename std::common_type <Sa, Sb>::type >
513
+ nda::Scalar S = std::common_type_t <Sa, Sb>>
555
514
nda::array<S, Ta::rank + Tb::rank - 2 > arraymult (Ta const &a, Tb const &b) {
556
515
557
516
// Get ranks of input arrays
0 commit comments