Skip to content

Commit 3cd37d0

Browse files
committed
Various simplifications in cppdlr/utils.hpp
1 parent 1660ab3 commit 3cd37d0

File tree

1 file changed

+48
-89
lines changed

1 file changed

+48
-89
lines changed

c++/cppdlr/utils.hpp

+48-89
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@
1919
#include <nda/nda.hpp>
2020
#include <nda/blas.hpp>
2121

22-
2322
namespace cppdlr {
24-
using dcomplex = std::complex<double>;
23+
24+
/**
25+
* Calculate the squared norm of a vector
26+
*
27+
* @param v The input vector
28+
* @return The squared norm of the vector
29+
*/
30+
double normsq(nda::MemoryVector auto const &v) {
  // dotc of the vector with itself; only the real part of that product is returned.
  auto const self_product = nda::blas::dotc(v, v);
  return nda::real(self_product);
}
2531

2632
/**
2733
* Class constructor for barycheb: barycentric Lagrange interpolation at
@@ -116,10 +122,10 @@ namespace cppdlr {
116122
// Compute norms of rows of input matrix, and rescale eps tolerance
117123
auto norms = nda::vector<double>(m);
118124
double epssq = eps * eps;
119-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
125+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
120126

121127
// Begin pivoted double Gram-Schmidt procedure
122-
int jpiv = 0, jj = 0;
128+
int jpiv = 0;
123129
double nrm = 0;
124130
auto piv = nda::arange(m);
125131
auto tmp = nda::vector<S>(n);
@@ -137,38 +143,29 @@ namespace cppdlr {
137143
}
138144

139145
// Swap current row with chosen pivot row
140-
tmp = aa(j, _);
141-
aa(j, _) = aa(jpiv, _);
142-
aa(jpiv, _) = tmp;
143-
144-
nrm = norms(j);
145-
norms(j) = norms(jpiv);
146-
norms(jpiv) = nrm;
147-
148-
jj = piv(j);
149-
piv(j) = piv(jpiv);
150-
piv(jpiv) = jj;
146+
deep_swap(aa(j, _), aa(jpiv, _));
147+
std::swap(norms(j), norms(jpiv));
148+
std::swap(piv(j), piv(jpiv));
151149

152150
// Orthogonalize current row (now the chosen pivot row) against all
153151
// previously chosen rows
154152
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
155153

156154
// Get norm of current row
157-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
158-
//nrm = nda::norm(aa(j, _));
155+
nrm = normsq(aa(j, _));
159156

160157
// Terminate if sufficiently small, and return previously selected rows
161158
// (not including current row)
162159
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };
163160

164161
// Normalize current row
165-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
162+
aa(j, _) /= sqrt(nrm);
166163

167164
// Orthogonalize remaining rows against current row
168165
for (int k = j + 1; k < m; ++k) {
169166
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
170167
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
171-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
168+
norms(k) = normsq(aa(k, _));
172169
}
173170
}
174171

@@ -211,22 +208,21 @@ namespace cppdlr {
211208
if (m % 2 != 0) { throw std::runtime_error("Input matrix must have even number of rows."); }
212209

213210
// Copy input data, re-ordering rows to make symmetric rows adjacent.
214-
auto aa = typename T::regular_type(m, n);
211+
auto aa = typename T::regular_type(m, n);
215212
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
216213
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);
217214

218215
// Compute norms of rows of input matrix, and rescale eps tolerance
219216
auto norms = nda::vector<double>(m);
220217
double epssq = eps * eps;
221-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
218+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
222219

223220
// Begin pivoted double Gram-Schmidt procedure
224-
int jpiv = 0, jj = 0;
225-
double nrm = 0;
226-
auto piv = nda::arange(0, m);
221+
int jpiv = 0;
222+
double nrm = 0;
223+
auto piv = nda::arange(0, m);
227224
piv(nda::range(0, m, 2)) = nda::arange(0, m / 2); // Re-order pivots to match re-ordered input matrix
228225
piv(nda::range(1, m, 2)) = nda::arange(m - 1, m / 2 - 1, -1);
229-
auto tmp = nda::vector<S>(n);
230226

231227
if (maxrnk % 2 != 0) { // If n < m and n is odd, decrease maxrnk to maintain symmetry
232228
maxrnk -= 1;
@@ -245,61 +241,46 @@ namespace cppdlr {
245241
}
246242

247243
// Swap current row pair with chosen pivot row pair
248-
tmp = aa(j, _);
249-
aa(j, _) = aa(jpiv, _);
250-
aa(jpiv, _) = tmp;
251-
tmp = aa(j + 1, _);
252-
aa(j + 1, _) = aa(jpiv + 1, _);
253-
aa(jpiv + 1, _) = tmp;
254-
255-
nrm = norms(j);
256-
norms(j) = norms(jpiv);
257-
norms(jpiv) = nrm;
258-
nrm = norms(j + 1);
259-
norms(j + 1) = norms(jpiv + 1);
260-
norms(jpiv + 1) = nrm;
261-
262-
jj = piv(j);
263-
piv(j) = piv(jpiv);
264-
piv(jpiv) = jj;
265-
jj = piv(j + 1);
266-
piv(j + 1) = piv(jpiv + 1);
267-
piv(jpiv + 1) = jj;
244+
deep_swap(aa(j, _), aa(jpiv, _));
245+
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
246+
std::swap(norms(j), norms(jpiv));
247+
std::swap(norms(j + 1), norms(jpiv + 1));
248+
std::swap(piv(j), piv(jpiv));
249+
std::swap(piv(j + 1), piv(jpiv + 1));
268250

269251
// Orthogonalize current row (now the first chosen pivot row) against all
270252
// previously chosen rows
271253
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
272254

273255
// Get norm of current row
274-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
256+
nrm = normsq(aa(j, _));
275257

276258
// Terminate if sufficiently small, and return previously selected rows
277259
// (not including current row)
278260
if (nrm <= epssq) { return {aa(nda::range(0, j), _), norms(nda::range(0, j)), piv(nda::range(0, j))}; };
279261

280262
// Normalize current row
281-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
263+
aa(j, _) /= sqrt(nrm);
282264

283265
// Orthogonalize remaining rows against current row
284266
for (int k = j + 1; k < m; ++k) {
285267
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
286268
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
287-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
269+
norms(k) = normsq(aa(k, _));
288270
}
289271

290272
// Orthogonalize current row (now the second chosen pivot row) against all
291273
// previously chosen rows
292274
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }
293275

294276
// Normalize current row
295-
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
296-
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
277+
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));
297278

298279
// Orthogonalize remaining rows against current row
299280
for (int k = j + 2; k < m; ++k) {
300281
if (norms(k) <= epssq) { continue; } // Can skip rows with norm less than tolerance
301282
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
302-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
283+
norms(k) = normsq(aa(k, _));
303284
}
304285
}
305286

@@ -352,18 +333,18 @@ namespace cppdlr {
352333
aa(nda::range(0, m, 2), _) = a(nda::range(0, m / 2), _);
353334
aa(nda::range(1, m, 2), _) = a(nda::range(m - 1, m / 2 - 1, -1), _);
354335
} else {
355-
aa(0, _) = a((m - 1) / 2, _);
336+
aa(0, _) = a((m - 1) / 2, _);
356337
aa(nda::range(1, m, 2), _) = a(nda::range(0, (m - 1) / 2), _);
357338
aa(nda::range(2, m, 2), _) = a(nda::range(m - 1, (m - 1) / 2, -1), _);
358339
//aa(m - 1, _) = a((m - 1) / 2, _);
359340
}
360341

361342
// Compute norms of rows of input matrix
362343
auto norms = nda::vector<double>(m);
363-
for (int j = 0; j < m; ++j) { norms(j) = nda::real(nda::blas::dotc(aa(j, _), aa(j, _))); }
344+
for (int j = 0; j < m; ++j) { norms(j) = normsq(aa(j, _)); }
364345

365346
// Begin pivoted double Gram-Schmidt procedure
366-
int jpiv = 0, jj = 0;
347+
int jpiv = 0;
367348
double nrm = 0;
368349
auto piv = nda::arange(0, m);
369350
if (m % 2 == 0) {
@@ -375,23 +356,17 @@ namespace cppdlr {
375356
piv(nda::range(2, m, 2)) = nda::arange(m - 1, (m - 1) / 2, -1);
376357
//piv(m - 1) = (m - 1) / 2;
377358
}
378-
auto tmp = nda::vector<S>(n);
379359

380360
// If m odd, first choose middle row (now first row) as first pivot
381361

382362
if (m % 2 == 1) {
383-
//int j = 0; // Index of current row
384-
//jpiv = 0; // Index of pivot row
385-
386363
// Normalize
387-
nrm = nda::real(nda::blas::dotc(aa(0, _), aa(0, _)));
388-
aa(0, _) = aa(0, _) * (1 / sqrt(nrm));
389-
//aa(0, _) /= sqrt(nda::real(nda::blas::dotc(aa(0, _), aa(0, _))));
364+
aa(0, _) /= sqrt(normsq(aa(0, _)));
390365

391366
// Orthogonalize remaining rows against current row
392367
for (int k = 1; k < m; ++k) {
393368
aa(k, _) = aa(k, _) - aa(0, _) * nda::blas::dotc(aa(0, _), aa(k, _));
394-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
369+
norms(k) = normsq(aa(k, _));
395370
}
396371
}
397372

@@ -410,53 +385,37 @@ namespace cppdlr {
410385
}
411386

412387
// Swap current row pair with chosen pivot row pair
413-
tmp = aa(j, _);
414-
aa(j, _) = aa(jpiv, _);
415-
aa(jpiv, _) = tmp;
416-
tmp = aa(j + 1, _);
417-
aa(j + 1, _) = aa(jpiv + 1, _);
418-
aa(jpiv + 1, _) = tmp;
419-
420-
nrm = norms(j);
421-
norms(j) = norms(jpiv);
422-
norms(jpiv) = nrm;
423-
nrm = norms(j + 1);
424-
norms(j + 1) = norms(jpiv + 1);
425-
norms(jpiv + 1) = nrm;
426-
427-
jj = piv(j);
428-
piv(j) = piv(jpiv);
429-
piv(jpiv) = jj;
430-
jj = piv(j + 1);
431-
piv(j + 1) = piv(jpiv + 1);
432-
piv(jpiv + 1) = jj;
388+
deep_swap(aa(j, _), aa(jpiv, _));
389+
deep_swap(aa(j + 1, _), aa(jpiv + 1, _));
390+
std::swap(norms(j), norms(jpiv));
391+
std::swap(norms(j + 1), norms(jpiv + 1));
392+
std::swap(piv(j), piv(jpiv));
393+
std::swap(piv(j + 1), piv(jpiv + 1));
433394

434395
// Orthogonalize current row (now the first chosen pivot row) against all
435396
// previously chosen rows
436397
for (int k = 0; k < j; ++k) { aa(j, _) = aa(j, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j, _)); }
437398

438399
// Normalize current row
439-
nrm = nda::real(nda::blas::dotc(aa(j, _), aa(j, _)));
440-
aa(j, _) = aa(j, _) * (1 / sqrt(nrm));
400+
aa(j, _) /= sqrt(normsq(aa(j, _)));
441401

442402
// Orthogonalize remaining rows against current row
443403
for (int k = j + 1; k < m; ++k) {
444404
aa(k, _) = aa(k, _) - aa(j, _) * nda::blas::dotc(aa(j, _), aa(k, _));
445-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
405+
norms(k) = normsq(aa(k, _));
446406
}
447407

448408
// Orthogonalize current row (now the second chosen pivot row) against all
449409
// previously chosen rows
450410
for (int k = 0; k < j + 1; ++k) { aa(j + 1, _) = aa(j + 1, _) - aa(k, _) * nda::blas::dotc(aa(k, _), aa(j + 1, _)); }
451411

452412
// Normalize current row
453-
nrm = nda::real(nda::blas::dotc(aa(j + 1, _), aa(j + 1, _)));
454-
aa(j + 1, _) = aa(j + 1, _) * (1 / sqrt(nrm));
413+
aa(j + 1, _) /= sqrt(normsq(aa(j + 1, _)));
455414

456415
// Orthogonalize remaining rows against current row
457416
for (int k = j + 2; k < m; ++k) {
458417
aa(k, _) = aa(k, _) - aa(j + 1, _) * nda::blas::dotc(aa(j + 1, _), aa(k, _));
459-
norms(k) = nda::real(nda::blas::dotc(aa(k, _), aa(k, _)));
418+
norms(k) = normsq(aa(k, _));
460419
}
461420
}
462421

@@ -551,7 +510,7 @@ namespace cppdlr {
551510
* @return Contraction of the inner dimensions of \p a and \p b
552511
*/
553512
template <nda::MemoryArray Ta, nda::MemoryArray Tb, nda::Scalar Sa = nda::get_value_t<Ta>, nda::Scalar Sb = nda::get_value_t<Tb>,
554-
nda::Scalar S = typename std::common_type<Sa, Sb>::type>
513+
nda::Scalar S = std::common_type_t<Sa, Sb>>
555514
nda::array<S, Ta::rank + Tb::rank - 2> arraymult(Ta const &a, Tb const &b) {
556515

557516
// Get ranks of input arrays

0 commit comments

Comments
 (0)