Skip to content

Commit 4852045

Browse files
author
Ruairi Moran
committed
resolve comments
1 parent adfc224 commit 4852045

File tree

2 files changed

+17
-20
lines changed

2 files changed

+17
-20
lines changed

include/tensor.cuh

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -391,12 +391,13 @@ public:
391391
T minAbs() const;
392392

393393
/**
394-
* Solves for the least squares solution of A \ b.
395-
* A is this tensor and b is the provided tensor.
394+
* Batch solves `A \ b`.
395+
* Solves `bi <- Ai \ bi` for each k-index `i`.
396+
* A is this (m,n,k)-tensor and b is the provided (m,1,k)-tensor.
396397
* A and b must have compatible dimensions (same number of rows and matrices).
397398
* A must be a square or tall matrix (m>=n).
398399
* @param b provided tensor
399-
* @return least squares solution (overwrites b)
400+
* @return nothing; the least squares solutions overwrite the (n,1,k)-part of b
400401
*/
401402
void leastSquaresBatched(DTensor &b);
402403

@@ -1356,8 +1357,8 @@ private:
13561357
public:
13571358

13581359
QRFactoriser(DTensor<T> &A) {
1359-
if (A.numMats() > 1) throw std::invalid_argument("[LeastSquares] 3D tensors require `leastSquaresBatched`");
1360-
if (A.numRows() < A.numCols()) throw std::invalid_argument("[Cholesky] Matrix A must be tall or square");
1360+
if (A.numMats() > 1) throw std::invalid_argument("[QR] 3D tensors require `leastSquaresBatched`");
1361+
if (A.numRows() < A.numCols()) throw std::invalid_argument("[QR] Matrix A must be tall or square");
13611362
m_matrix = &A;
13621363
computeWorkspaceSize();
13631364
m_workspace = std::make_unique<DTensor<T>>(m_workspaceSize);
@@ -1372,7 +1373,7 @@ public:
13721373
int factorise();
13731374

13741375
/**
1375-
* Solves for the solution of A \ b using the QR of A.
1376+
* Solves A \ b using the QR of A.
13761377
* A is the matrix that is factorised and b is the provided matrix.
13771378
* A and b must have compatible dimensions (same number of rows and matrices=1).
13781379
* A must be tall or square (m>=n).
@@ -1482,15 +1483,13 @@ inline int QRFactoriser<double>::getQR(DTensor<double> &Q, DTensor<double> &R) {
14821483
size_t m = m_matrix->numRows();
14831484
size_t n = m_matrix->numCols();
14841485
if (Q.numRows() != m || Q.numCols() != n)
1485-
throw std::invalid_argument("[QRFactoriser] invalid shape of Q.");
1486+
throw std::invalid_argument("[QR] invalid shape of Q.");
14861487
if (R.numRows() != n || R.numCols() != n)
1487-
throw std::invalid_argument("[QRFactoriser] invalid shape of R.");
1488+
throw std::invalid_argument("[QR] invalid shape of R.");
14881489
// Initialise Q to zeros with ones on its main diagonal
14891490
std::vector<double> vecQ(m * n, 0.);
1490-
for (int r = 0; r < m; r++) {
1491-
for (int c = 0; c < n; c++) {
1492-
if (r == c) { vecQ[r * n + c] = 1.; }
1493-
}
1491+
for (size_t i = 0; i < m; i++) {
1492+
vecQ[i * n + i] = 1.;
14941493
}
14951494
Q.upload(vecQ, rowMajor);
14961495
// Apply Householder reflectors to compute Q
@@ -1517,15 +1516,13 @@ inline int QRFactoriser<float>::getQR(DTensor<float> &Q, DTensor<float> &R) {
15171516
size_t m = m_matrix->numRows();
15181517
size_t n = m_matrix->numCols();
15191518
if (Q.numRows() != m || Q.numCols() != n)
1520-
throw std::invalid_argument("[QRFactoriser] invalid shape of Q.");
1519+
throw std::invalid_argument("[QR] invalid shape of Q.");
15211520
if (R.numRows() != n || R.numCols() != n)
1522-
throw std::invalid_argument("[QRFactoriser] invalid shape of R.");
1521+
throw std::invalid_argument("[QR] invalid shape of R.");
15231522
// Initialise Q to zeros with ones on its main diagonal
15241523
std::vector<float> vecQ(m * n, 0.);
1525-
for (int r = 0; r < m; r++) {
1526-
for (int c = 0; c < n; c++) {
1527-
if (r == c) { vecQ[r * n + c] = 1.; }
1528-
}
1524+
for (size_t i = 0; i < m; i++) {
1525+
vecQ[i * n + i] = 1.;
15291526
}
15301527
Q.upload(vecQ, rowMajor);
15311528
// Apply Householder reflectors to compute Q

test/testTensor.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@ protected:
10501050

10511051

10521052
/* ---------------------------------------
1053-
* Cholesky factorisation
1053+
* QR factorisation
10541054
* --------------------------------------- */
10551055

10561056
TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX
@@ -1080,7 +1080,7 @@ TEST_F(QRTest, qrFactorisation) {
10801080
}
10811081

10821082
/* ---------------------------------------
1083-
* Cholesky factorisation: solve system
1083+
* QR factorisation: solve least squares
10841084
* --------------------------------------- */
10851085

10861086
TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX

0 commit comments

Comments
 (0)