@@ -391,12 +391,13 @@ public:
391391 T minAbs () const ;
392392
393393 /* *
394- * Solves for the least squares solution of A \ b.
395- * A is this tensor and b is the provided tensor.
394+ * Batch solves `A \ b`.
395+ * Solves `bi <- Ai \ bi` for each k-index `i`.
396+ * A is this (m,n,k)-tensor and b is the provided (m,1,k)-tensor.
396397 * A and b must have compatible dimensions (same number of rows and matrices).
397398 * A must be a square or tall matrix (m>=n).
398399 * @param b provided tensor
399- * @return least squares solution (overwrites b)
400+ * @return least squares solutions (overwrites (n,1,k)-part of b)
400401 */
401402 void leastSquaresBatched (DTensor &b);
402403
@@ -1356,8 +1357,8 @@ private:
13561357public:
13571358
13581359 QRFactoriser (DTensor<T> &A) {
1359- if (A.numMats () > 1 ) throw std::invalid_argument (" [LeastSquares ] 3D tensors require `leastSquaresBatched`" );
1360- if (A.numRows () < A.numCols ()) throw std::invalid_argument (" [Cholesky ] Matrix A must be tall or square" );
1360+ if (A.numMats () > 1 ) throw std::invalid_argument (" [QR ] 3D tensors require `leastSquaresBatched`" );
1361+ if (A.numRows () < A.numCols ()) throw std::invalid_argument (" [QR ] Matrix A must be tall or square" );
13611362 m_matrix = &A;
13621363 computeWorkspaceSize ();
13631364 m_workspace = std::make_unique<DTensor<T>>(m_workspaceSize);
@@ -1372,7 +1373,7 @@ public:
13721373 int factorise ();
13731374
13741375 /* *
1375- * Solves for the solution of A \ b using the QR of A.
1376+ * Solves A \ b using the QR of A.
13761377 * A is the matrix that is factorised and b is the provided matrix.
13771378 * A and b must have compatible dimensions (same number of rows and matrices=1).
13781379 * A must be tall or square (m>=n).
@@ -1482,15 +1483,13 @@ inline int QRFactoriser<double>::getQR(DTensor<double> &Q, DTensor<double> &R) {
14821483 size_t m = m_matrix->numRows ();
14831484 size_t n = m_matrix->numCols ();
14841485 if (Q.numRows () != m || Q.numCols () != n)
1485- throw std::invalid_argument (" [QRFactoriser ] invalid shape of Q." );
1486+ throw std::invalid_argument (" [QR ] invalid shape of Q." );
14861487 if (R.numRows () != n || R.numCols () != n)
1487- throw std::invalid_argument (" [QRFactoriser ] invalid shape of R." );
1488+ throw std::invalid_argument (" [QR ] invalid shape of R." );
14881489 // Initialize Q to 1's on diagonal
14891490 std::vector<double > vecQ (m * n, 0 .);
1490- for (int r = 0 ; r < m; r++) {
1491- for (int c = 0 ; c < n; c++) {
1492- if (r == c) { vecQ[r * n + c] = 1 .; }
1493- }
1491+ for (size_t i = 0 ; i < m; i++) {
1492+ vecQ[i * n + i] = 1 .;
14941493 }
14951494 Q.upload (vecQ, rowMajor);
14961495 // Apply Householder reflectors to compute Q
@@ -1517,15 +1516,13 @@ inline int QRFactoriser<float>::getQR(DTensor<float> &Q, DTensor<float> &R) {
15171516 size_t m = m_matrix->numRows ();
15181517 size_t n = m_matrix->numCols ();
15191518 if (Q.numRows () != m || Q.numCols () != n)
1520- throw std::invalid_argument (" [QRFactoriser ] invalid shape of Q." );
1519+ throw std::invalid_argument (" [QR ] invalid shape of Q." );
15211520 if (R.numRows () != n || R.numCols () != n)
1522- throw std::invalid_argument (" [QRFactoriser ] invalid shape of R." );
1521+ throw std::invalid_argument (" [QR ] invalid shape of R." );
15231522 // Initialize Q to 1's on diagonal
15241523 std::vector<float > vecQ (m * n, 0 .);
1525- for (int r = 0 ; r < m; r++) {
1526- for (int c = 0 ; c < n; c++) {
1527- if (r == c) { vecQ[r * n + c] = 1 .; }
1528- }
1524+ for (size_t i = 0 ; i < m; i++) {
1525+ vecQ[i * n + i] = 1 .;
15291526 }
15301527 Q.upload (vecQ, rowMajor);
15311528 // Apply Householder reflectors to compute Q
0 commit comments