Commit 199bb73

Merge pull request #41 from GPUEngineering/f/ls
QR decomposition and least squares
2 parents f62aa1c + 4852045

File tree

3 files changed: +316, -17 lines changed

README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -236,7 +236,7 @@ DTensor<float> B(bData, m);
 Then, we can solve the system by
 
 ```c++
-A.leastSquares(B);
+A.leastSquaresBatched(B);
 ```
 
 The `DTensor` `B` will be overwritten with the solution.
````
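For context, the renamed method solves one least-squares problem per matrix in the batch, overwriting the first n rows of each right-hand side with the corresponding solution. A minimal end-to-end sketch (sizes and data are made up, and it assumes `DTensor`'s vector-based constructor `DTensor(data, rows, cols, mats)` used elsewhere in the README):

```c++
// Three 4x2 systems solved in one batched call (m >= n is required).
size_t m = 4, n = 2, k = 3;
std::vector<float> aData(m * n * k, 1.0f);  // placeholder data
std::vector<float> bData(m * 1 * k, 2.0f);  // placeholder right-hand sides
DTensor<float> A(aData, m, n, k);
DTensor<float> B(bData, m, 1, k);
A.leastSquaresBatched(B);  // B(0..n-1, 0, i) now holds the solution of Ai \ bi
```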

include/tensor.cuh

Lines changed: 232 additions & 15 deletions
```diff
@@ -391,14 +391,15 @@ public:
     T minAbs() const;
 
     /**
-     * Solves for the least squares solution of A \ b.
-     * A is this tensor and b is the provided tensor.
+     * Batch solves `A \ b`.
+     * Solves `bi <- Ai \ bi` for each k-index `i`.
+     * A is this (m,n,k)-tensor and b is the provided (m,1,k)-tensor.
      * A and b must have compatible dimensions (same number of rows and matrices).
      * A must be a square or tall matrix (m>=n).
      * @param b provided tensor
-     * @return least squares solution (overwrites b)
+     * @return least squares solutions (overwrites (n,1,k)-part of b)
      */
-    void leastSquares(DTensor &b);
+    void leastSquaresBatched(DTensor &b);
 
     /**
      * Batched `C <- bC + a*A*B`.
```
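Stated as an equation, the renamed routine computes, for each matrix index i in the batch,

```latex
x_i = \arg\min_{x \in \mathbb{R}^n} \| A_i x - b_i \|_2 , \qquad i = 0, \dots, k - 1,
```

and each solution x_i overwrites the first n entries of the corresponding b_i, as the updated docstring notes.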
```diff
@@ -884,17 +885,17 @@ inline void DTensor<float>::addAB(const DTensor<float> &A, const DTensor<float>
 }
 
 template<>
-inline void DTensor<double>::leastSquares(DTensor &B) {
+inline void DTensor<double>::leastSquaresBatched(DTensor &B) {
     size_t batchSize = numMats();
     size_t nColsB = B.numCols();
     if (B.numRows() != m_numRows)
-        throw std::invalid_argument("[Least squares] rhs rows does not equal lhs rows");
+        throw std::invalid_argument("[Least squares batched] rhs rows does not equal lhs rows");
     if (nColsB != 1)
-        throw std::invalid_argument("[Least squares] rhs are not vectors");
+        throw std::invalid_argument("[Least squares batched] rhs are not vectors");
     if (B.numMats() != batchSize)
-        throw std::invalid_argument("[Least squares] rhs numMats does not equal lhs numMats");
+        throw std::invalid_argument("[Least squares batched] rhs numMats does not equal lhs numMats");
     if (m_numCols > m_numRows)
-        throw std::invalid_argument("[Least squares] supports square or tall matrices only");
+        throw std::invalid_argument("[Least squares batched] supports square or tall matrices only");
     int info = 0;
     DTensor<int> infoArray(batchSize);
     DTensor<double *> As = pointersToMatrices();
```
```diff
@@ -914,17 +915,17 @@ inline void DTensor<double>::leastSquares(DTensor &B) {
 }
 
 template<>
-inline void DTensor<float>::leastSquares(DTensor &B) {
+inline void DTensor<float>::leastSquaresBatched(DTensor &B) {
     size_t batchSize = numMats();
     size_t nColsB = B.numCols();
     if (B.numRows() != m_numRows)
-        throw std::invalid_argument("[Least squares] rhs rows does not equal lhs rows");
+        throw std::invalid_argument("[Least squares batched] rhs rows does not equal lhs rows");
     if (nColsB != 1)
-        throw std::invalid_argument("[Least squares] rhs are not vectors");
+        throw std::invalid_argument("[Least squares batched] rhs are not vectors");
     if (B.numMats() != batchSize)
-        throw std::invalid_argument("[Least squares] rhs numMats does not equal lhs numMats");
+        throw std::invalid_argument("[Least squares batched] rhs numMats does not equal lhs numMats");
     if (m_numCols > m_numRows)
-        throw std::invalid_argument("[Least squares] supports square or tall matrices only");
+        throw std::invalid_argument("[Least squares batched] supports square or tall matrices only");
     int info = 0;
     DTensor<int> infoArray(batchSize);
     DTensor<float *> As = pointersToMatrices();
```
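Both hunks end in unchanged context just before the actual solver call, so that call is not shown here. For orientation only: batched least squares of this shape is what cuBLAS's `cublas<t>gelsBatched` routines provide, taking arrays of device pointers (as built by `pointersToMatrices()`) plus a per-batch info array. A hedged sketch of such a call for the `float` case, not necessarily the exact code in the unchanged remainder of this function:

```c++
// Sketch: QR-based batched least squares in cuBLAS (float variant).
// `Bs` is assumed to be B.pointersToMatrices(); sizes as validated above.
DTensor<float *> Bs = B.pointersToMatrices();
gpuErrChk(cublasSgelsBatched(Session::getInstance().cuBlasHandle(),
                             CUBLAS_OP_N,
                             m_numRows, m_numCols, nColsB,
                             As.raw(), m_numRows,
                             Bs.raw(), m_numRows,
                             &info,            // host status
                             infoArray.raw(),  // per-batch device status
                             batchSize));
```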
```diff
@@ -1238,7 +1239,7 @@ private:
 public:
 
     CholeskyFactoriser(DTensor<T> &A) {
-        if (A.numMats() > 1) throw std::invalid_argument("[Cholesky] 3D tensors require `CholeskyBatchFactoriser");
+        if (A.numMats() > 1) throw std::invalid_argument("[Cholesky] 3D tensors require `CholeskyBatchFactoriser`");
         if (A.numRows() != A.numCols()) throw std::invalid_argument("[Cholesky] Matrix A must be square");
         m_matrix = &A;
         computeWorkspaceSize();
```
```diff
@@ -1328,6 +1329,222 @@ inline int CholeskyFactoriser<float>::solve(DTensor<float> &rhs) {
 }
 
 
+/* ================================================================================================
+ * QR DECOMPOSITION (QR)
+ * ================================================================================================ */
+
+/**
+ * QR decomposition (QR) needs a workspace to be set up for cuSolver before factorisation.
+ * This object can be set up for a specific type and size of (m,n,1)-tensor (i.e., a matrix).
+ * Then, many same-type (m,n,1)-tensors can be factorised using this object's workspace.
+ * @tparam T data type of (m,n,1)-tensor to be factorised (must be float or double)
+ */
+TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX
+class QRFactoriser {
+
+private:
+    int m_workspaceSize = 0; ///< Size of workspace needed for least squares
+    std::unique_ptr<DTensor<int>> m_info; ///< Status code of computation
+    std::unique_ptr<DTensor<T>> m_householder; ///< For storing Householder reflectors
+    std::unique_ptr<DTensor<T>> m_workspace; ///< Workspace for least squares
+    DTensor<T> *m_matrix; ///< Lhs matrix template. Do not destroy!
+
+    /**
+     * Computes the workspace size required by cuSolver.
+     */
+    void computeWorkspaceSize();
+
+public:
+
+    QRFactoriser(DTensor<T> &A) {
+        if (A.numMats() > 1) throw std::invalid_argument("[QR] 3D tensors require `leastSquaresBatched`");
+        if (A.numRows() < A.numCols()) throw std::invalid_argument("[QR] Matrix A must be tall or square");
+        m_matrix = &A;
+        computeWorkspaceSize();
+        m_workspace = std::make_unique<DTensor<T>>(m_workspaceSize);
+        m_householder = std::make_unique<DTensor<T>>(m_matrix->numCols());
+        m_info = std::make_unique<DTensor<int>>(1);
+    }
+
+    /**
+     * Factorise matrix.
+     * @return status code of computation
+     */
+    int factorise();
+
+    /**
+     * Solves A \ b using the QR of A.
+     * A is the matrix that was factorised and b is the provided right-hand side.
+     * A and b must have compatible dimensions (same number of rows; exactly one matrix).
+     * A must be tall or square (m>=n).
+     * @param b provided right-hand side
+     * @return status code of computation
+     */
+    int leastSquares(DTensor<T> &);
+
+    /**
+     * Populate the given tensors with Q and R.
+     * Caution! This is an inefficient method: only to be used for debugging.
+     * @return status code of computation (Q and R are overwritten)
+     */
+    int getQR(DTensor<T> &, DTensor<T> &);
+
+};
+
+template<>
+inline void QRFactoriser<double>::computeWorkspaceSize() {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    gpuErrChk(cusolverDnDgeqrf_bufferSize(Session::getInstance().cuSolverHandle(),
+                                          m, n,
+                                          nullptr, m,
+                                          &m_workspaceSize));
+}
+
+template<>
+inline void QRFactoriser<float>::computeWorkspaceSize() {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    gpuErrChk(cusolverDnSgeqrf_bufferSize(Session::getInstance().cuSolverHandle(),
+                                          m, n,
+                                          nullptr, m,
+                                          &m_workspaceSize));
+}
+
+template<>
+inline int QRFactoriser<double>::factorise() {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    gpuErrChk(cusolverDnDgeqrf(Session::getInstance().cuSolverHandle(),
+                               m, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    return (*m_info)(0);
+}
+
+template<>
+inline int QRFactoriser<float>::factorise() {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    gpuErrChk(cusolverDnSgeqrf(Session::getInstance().cuSolverHandle(),
+                               m, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    return (*m_info)(0);
+}
+
+template<>
+inline int QRFactoriser<double>::leastSquares(DTensor<double> &rhs) {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    double alpha = 1.;
+    gpuErrChk(cusolverDnDormqr(Session::getInstance().cuSolverHandle(),
+                               CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, 1, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               rhs.raw(), m,
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    gpuErrChk(cublasDtrsm(Session::getInstance().cuBlasHandle(),
+                          CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, 1,
+                          &alpha,
+                          m_matrix->raw(), m,
+                          rhs.raw(), m));
+    return (*m_info)(0);
+}
+
+template<>
+inline int QRFactoriser<float>::leastSquares(DTensor<float> &rhs) {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    float alpha = 1.;
+    gpuErrChk(cusolverDnSormqr(Session::getInstance().cuSolverHandle(),
+                               CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, 1, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               rhs.raw(), m,
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    gpuErrChk(cublasStrsm(Session::getInstance().cuBlasHandle(),
+                          CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, 1,
+                          &alpha,
+                          m_matrix->raw(), m,
+                          rhs.raw(), m));
+    return (*m_info)(0);
+}
+
+template<>
+inline int QRFactoriser<double>::getQR(DTensor<double> &Q, DTensor<double> &R) {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    if (Q.numRows() != m || Q.numCols() != n)
+        throw std::invalid_argument("[QR] invalid shape of Q.");
+    if (R.numRows() != n || R.numCols() != n)
+        throw std::invalid_argument("[QR] invalid shape of R.");
+    // Initialise Q to 1s on the diagonal of its leading n-by-n block (Q is m-by-n)
+    std::vector<double> vecQ(m * n, 0.);
+    for (size_t i = 0; i < n; i++) {
+        vecQ[i * n + i] = 1.;
+    }
+    Q.upload(vecQ, rowMajor);
+    // Apply Householder reflectors to compute thin Q
+    gpuErrChk(cusolverDnDormqr(Session::getInstance().cuSolverHandle(),
+                               CUBLAS_SIDE_LEFT, CUBLAS_OP_N, m, n, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               Q.raw(), m,
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    // Extract upper triangular R
+    std::vector<double> vecR(n * n, 0.);
+    for (size_t r = 0; r < n; r++) {
+        for (size_t c = r; c < n; c++) {
+            vecR[r * n + c] = (*m_matrix)(r, c);
+        }
+    }
+    R.upload(vecR, rowMajor);
+    return (*m_info)(0);
+}
+
+template<>
+inline int QRFactoriser<float>::getQR(DTensor<float> &Q, DTensor<float> &R) {
+    size_t m = m_matrix->numRows();
+    size_t n = m_matrix->numCols();
+    if (Q.numRows() != m || Q.numCols() != n)
+        throw std::invalid_argument("[QR] invalid shape of Q.");
+    if (R.numRows() != n || R.numCols() != n)
+        throw std::invalid_argument("[QR] invalid shape of R.");
+    // Initialise Q to 1s on the diagonal of its leading n-by-n block (Q is m-by-n)
+    std::vector<float> vecQ(m * n, 0.);
+    for (size_t i = 0; i < n; i++) {
+        vecQ[i * n + i] = 1.;
+    }
+    Q.upload(vecQ, rowMajor);
+    // Apply Householder reflectors to compute thin Q
+    gpuErrChk(cusolverDnSormqr(Session::getInstance().cuSolverHandle(),
+                               CUBLAS_SIDE_LEFT, CUBLAS_OP_N, m, n, n,
+                               m_matrix->raw(), m,
+                               m_householder->raw(),
+                               Q.raw(), m,
+                               m_workspace->raw(), m_workspaceSize,
+                               m_info->raw()));
+    // Extract upper triangular R
+    std::vector<float> vecR(n * n, 0.);
+    for (size_t r = 0; r < n; r++) {
+        for (size_t c = r; c < n; c++) {
+            vecR[r * n + c] = (*m_matrix)(r, c);
+        }
+    }
+    R.upload(vecR, rowMajor);
+    return (*m_info)(0);
+}
+
+
 /* ================================================================================================
  * Nullspace (N)
  * ================================================================================================ */
```
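The new `leastSquares` member applies the two standard QR least-squares steps visible above: `ormqr` overwrites b with Q^T b in place, and `trsm` then solves the triangular system R x = Q^T b, leaving x in the first n entries of b. A minimal usage sketch of the new class (sizes and data are made up; the `DTensor(rows, cols)` and vector-based constructors are assumed from the surrounding library):

```c++
size_t m = 4, n = 2;
std::vector<double> aData(m * n, 1.0);  // placeholder data
std::vector<double> bData(m, 2.0);      // placeholder right-hand side
DTensor<double> A(aData, m, n);
DTensor<double> b(bData, m);

QRFactoriser<double> qr(A);
qr.factorise();      // A is overwritten with R and the Householder reflectors
qr.leastSquares(b);  // first n entries of b now hold the solution of A \ b

DTensor<double> Q(m, n), R(n, n);
qr.getQR(Q, R);      // debugging only: materialise thin Q and R
```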
