@@ -391,14 +391,15 @@ public:
391391 T minAbs () const ;
392392
393393 /* *
394- * Solves for the least squares solution of A \ b.
395- * A is this tensor and b is the provided tensor.
394+ * Batch solves `A \ b`.
395+ * Solves `bi <- Ai \ bi` for each k-index `i`.
396+ * A is this (m,n,k)-tensor and b is the provided (m,1,k)-tensor.
396397 * A and b must have compatible dimensions (same number of rows and matrices).
397398 * A must be a square or tall matrix (m>=n).
398399 * @param b provided tensor
399- * @return least squares solution (overwrites b)
400+ * @return least squares solutions (overwrites (n,1,k)-part of b)
400401 */
401- void leastSquares (DTensor &b);
402+ void leastSquaresBatched (DTensor &b);
402403
403404 /* *
404405 * Batched `C <- bC + a*A*B`.
@@ -884,17 +885,17 @@ inline void DTensor<float>::addAB(const DTensor<float> &A, const DTensor<float>
884885}
885886
886887template <>
887- inline void DTensor<double >::leastSquares (DTensor &B) {
888+ inline void DTensor<double >::leastSquaresBatched (DTensor &B) {
888889 size_t batchSize = numMats ();
889890 size_t nColsB = B.numCols ();
890891 if (B.numRows () != m_numRows)
891- throw std::invalid_argument (" [Least squares] rhs rows does not equal lhs rows" );
892+ throw std::invalid_argument (" [Least squares batched ] rhs rows does not equal lhs rows" );
892893 if (nColsB != 1 )
893- throw std::invalid_argument (" [Least squares] rhs are not vectors" );
894+ throw std::invalid_argument (" [Least squares batched ] rhs are not vectors" );
894895 if (B.numMats () != batchSize)
895- throw std::invalid_argument (" [Least squares] rhs numMats does not equal lhs numMats" );
896+ throw std::invalid_argument (" [Least squares batched ] rhs numMats does not equal lhs numMats" );
896897 if (m_numCols > m_numRows)
897- throw std::invalid_argument (" [Least squares] supports square or tall matrices only" );
898+ throw std::invalid_argument (" [Least squares batched ] supports square or tall matrices only" );
898899 int info = 0 ;
899900 DTensor<int > infoArray (batchSize);
900901 DTensor<double *> As = pointersToMatrices ();
@@ -914,17 +915,17 @@ inline void DTensor<double>::leastSquares(DTensor &B) {
914915}
915916
916917template <>
917- inline void DTensor<float >::leastSquares (DTensor &B) {
918+ inline void DTensor<float >::leastSquaresBatched (DTensor &B) {
918919 size_t batchSize = numMats ();
919920 size_t nColsB = B.numCols ();
920921 if (B.numRows () != m_numRows)
921- throw std::invalid_argument (" [Least squares] rhs rows does not equal lhs rows" );
922+ throw std::invalid_argument (" [Least squares batched ] rhs rows does not equal lhs rows" );
922923 if (nColsB != 1 )
923- throw std::invalid_argument (" [Least squares] rhs are not vectors" );
924+ throw std::invalid_argument (" [Least squares batched ] rhs are not vectors" );
924925 if (B.numMats () != batchSize)
925- throw std::invalid_argument (" [Least squares] rhs numMats does not equal lhs numMats" );
926+ throw std::invalid_argument (" [Least squares batched ] rhs numMats does not equal lhs numMats" );
926927 if (m_numCols > m_numRows)
927- throw std::invalid_argument (" [Least squares] supports square or tall matrices only" );
928+ throw std::invalid_argument (" [Least squares batched ] supports square or tall matrices only" );
928929 int info = 0 ;
929930 DTensor<int > infoArray (batchSize);
930931 DTensor<float *> As = pointersToMatrices ();
@@ -1238,7 +1239,7 @@ private:
12381239public:
12391240
12401241 CholeskyFactoriser (DTensor<T> &A) {
1241- if (A.numMats () > 1 ) throw std::invalid_argument (" [Cholesky] 3D tensors require `CholeskyBatchFactoriser" );
1242+ if (A.numMats () > 1 ) throw std::invalid_argument (" [Cholesky] 3D tensors require `CholeskyBatchFactoriser` " );
12421243 if (A.numRows () != A.numCols ()) throw std::invalid_argument (" [Cholesky] Matrix A must be square" );
12431244 m_matrix = &A;
12441245 computeWorkspaceSize ();
@@ -1328,6 +1329,222 @@ inline int CholeskyFactoriser<float>::solve(DTensor<float> &rhs) {
13281329}
13291330
13301331
1332+ /* ================================================================================================
1333+ * QR DECOMPOSITION (QR)
1334+ * ================================================================================================ */
1335+
1336+ /* *
1337+ * QR decomposition (QR) needs a workspace to be setup for cuSolver before factorisation.
1338+ * This object can be setup for a specific type and size of (m,n,1)-tensor (i.e., a matrix).
1339+ * Then, many same-type-(m,n,1)-tensor can be factorised using this object's workspace
1340+ * @tparam T data type of (m,n,1)-tensor to be factorised (must be float or double)
1341+ */
1342+ TEMPLATE_WITH_TYPE_T TEMPLATE_CONSTRAINT_REQUIRES_FPX
1343+ class QRFactoriser {
1344+
1345+ private:
1346+ int m_workspaceSize = 0 ; // /< Size of workspace needed for LS
1347+ std::unique_ptr<DTensor<int >> m_info; // /< Status code of computation
1348+ std::unique_ptr<DTensor<T>> m_householder; // /< For storing householder reflectors
1349+ std::unique_ptr<DTensor<T>> m_workspace; // /< Workspace for LS
1350+ DTensor<T> *m_matrix; // /< Lhs matrix template. Do not destroy!
1351+
1352+ /* *
1353+ * Computes the workspace size required by cuSolver.
1354+ */
1355+ void computeWorkspaceSize ();
1356+
1357+ public:
1358+
1359+ QRFactoriser (DTensor<T> &A) {
1360+ if (A.numMats () > 1 ) throw std::invalid_argument (" [QR] 3D tensors require `leastSquaresBatched`" );
1361+ if (A.numRows () < A.numCols ()) throw std::invalid_argument (" [QR] Matrix A must be tall or square" );
1362+ m_matrix = &A;
1363+ computeWorkspaceSize ();
1364+ m_workspace = std::make_unique<DTensor<T>>(m_workspaceSize);
1365+ m_householder = std::make_unique<DTensor<T>>(m_matrix->numCols ());
1366+ m_info = std::make_unique<DTensor<int >>(1 );
1367+ }
1368+
1369+ /* *
1370+ * Factorise matrix.
1371+ * @return status code of computation
1372+ */
1373+ int factorise ();
1374+
1375+ /* *
1376+ * Solves A \ b using the QR of A.
1377+ * A is the matrix that is factorised and b is the provided matrix.
1378+ * A and b must have compatible dimensions (same number of rows and matrices=1).
1379+ * A must be tall or square (m>=n).
1380+ * @param b provided matrix
1381+ * @return status code of computation
1382+ */
1383+ int leastSquares (DTensor<T> &);
1384+
1385+ /* *
1386+ * Populate the given tensors with Q and R.
1387+ * Caution! This is an inefficient method: only to be used for debugging.
1388+ * @return resulting Q and R from factorisation
1389+ */
1390+ int getQR (DTensor<T> &, DTensor<T> &);
1391+
1392+ };
1393+
1394+ template <>
1395+ inline void QRFactoriser<double >::computeWorkspaceSize() {
1396+ size_t m = m_matrix->numRows ();
1397+ size_t n = m_matrix->numCols ();
1398+ gpuErrChk (cusolverDnDgeqrf_bufferSize (Session::getInstance ().cuSolverHandle (),
1399+ m, n,
1400+ nullptr , m,
1401+ &m_workspaceSize));
1402+ }
1403+
1404+ template <>
1405+ inline void QRFactoriser<float >::computeWorkspaceSize() {
1406+ size_t m = m_matrix->numRows ();
1407+ size_t n = m_matrix->numCols ();
1408+ gpuErrChk (cusolverDnSgeqrf_bufferSize (Session::getInstance ().cuSolverHandle (),
1409+ m, n,
1410+ nullptr , m,
1411+ &m_workspaceSize));
1412+ }
1413+
1414+ template <>
1415+ inline int QRFactoriser<double >::factorise() {
1416+ size_t m = m_matrix->numRows ();
1417+ size_t n = m_matrix->numCols ();
1418+ gpuErrChk (cusolverDnDgeqrf (Session::getInstance ().cuSolverHandle (),
1419+ m, n,
1420+ m_matrix->raw (), m,
1421+ m_householder->raw (),
1422+ m_workspace->raw (), m_workspaceSize,
1423+ m_info->raw ()));
1424+ return (*m_info)(0 );
1425+ }
1426+
1427+
1428+ template <>
1429+ inline int QRFactoriser<float >::factorise() {
1430+ size_t m = m_matrix->numRows ();
1431+ size_t n = m_matrix->numCols ();
1432+ gpuErrChk (cusolverDnSgeqrf (Session::getInstance ().cuSolverHandle (),
1433+ m, n,
1434+ m_matrix->raw (), m,
1435+ m_householder->raw (),
1436+ m_workspace->raw (), m_workspaceSize,
1437+ m_info->raw ()));
1438+ return (*m_info)(0 );
1439+ }
1440+
1441+ template <>
1442+ inline int QRFactoriser<double >::leastSquares(DTensor<double > &rhs) {
1443+ size_t m = m_matrix->numRows ();
1444+ size_t n = m_matrix->numCols ();
1445+ double alpha = 1 .;
1446+ gpuErrChk (cusolverDnDormqr (Session::getInstance ().cuSolverHandle (),
1447+ CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, 1 , n,
1448+ m_matrix->raw (), m,
1449+ m_householder->raw (),
1450+ rhs.raw (), m,
1451+ m_workspace->raw (), m_workspaceSize,
1452+ m_info->raw ()));
1453+ gpuErrChk (cublasDtrsm (Session::getInstance ().cuBlasHandle (),
1454+ CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, 1 ,
1455+ &alpha,
1456+ m_matrix->raw (), m,
1457+ rhs.raw (), m));
1458+ return (*m_info)(0 );
1459+ }
1460+
1461+ template <>
1462+ inline int QRFactoriser<float >::leastSquares(DTensor<float > &rhs) {
1463+ size_t m = m_matrix->numRows ();
1464+ size_t n = m_matrix->numCols ();
1465+ float alpha = 1 .;
1466+ gpuErrChk (cusolverDnSormqr (Session::getInstance ().cuSolverHandle (),
1467+ CUBLAS_SIDE_LEFT, CUBLAS_OP_T, m, 1 , n,
1468+ m_matrix->raw (), m,
1469+ m_householder->raw (),
1470+ rhs.raw (), m,
1471+ m_workspace->raw (), m_workspaceSize,
1472+ m_info->raw ()));
1473+ gpuErrChk (cublasStrsm (Session::getInstance ().cuBlasHandle (),
1474+ CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, n, 1 ,
1475+ &alpha,
1476+ m_matrix->raw (), m,
1477+ rhs.raw (), m));
1478+ return (*m_info)(0 );
1479+ }
1480+
1481+ template <>
1482+ inline int QRFactoriser<double >::getQR(DTensor<double > &Q, DTensor<double > &R) {
1483+ size_t m = m_matrix->numRows ();
1484+ size_t n = m_matrix->numCols ();
1485+ if (Q.numRows () != m || Q.numCols () != n)
1486+ throw std::invalid_argument (" [QR] invalid shape of Q." );
1487+ if (R.numRows () != n || R.numCols () != n)
1488+ throw std::invalid_argument (" [QR] invalid shape of R." );
1489+ // Initialize Q to 1's on diagonal
1490+ std::vector<double > vecQ (m * n, 0 .);
1491+ for (size_t i = 0 ; i < m; i++) {
1492+ vecQ[i * n + i] = 1 .;
1493+ }
1494+ Q.upload (vecQ, rowMajor);
1495+ // Apply Householder reflectors to compute Q
1496+ gpuErrChk (cusolverDnDormqr (Session::getInstance ().cuSolverHandle (),
1497+ CUBLAS_SIDE_LEFT, CUBLAS_OP_N, m, n, n,
1498+ m_matrix->raw (), m,
1499+ m_householder->raw (),
1500+ Q.raw (), m,
1501+ m_workspace->raw (), m_workspaceSize,
1502+ m_info->raw ()));
1503+ // Extract upper triangular R
1504+ std::vector<double > vecR (n * n, 0 .);
1505+ for (size_t r = 0 ; r < n; r++) {
1506+ for (size_t c = r; c < n; c++) {
1507+ vecR[r * n + c] = (*m_matrix)(r, c);
1508+ }
1509+ }
1510+ R.upload (vecR, rowMajor);
1511+ return (*m_info)(0 );
1512+ }
1513+
1514+ template <>
1515+ inline int QRFactoriser<float >::getQR(DTensor<float > &Q, DTensor<float > &R) {
1516+ size_t m = m_matrix->numRows ();
1517+ size_t n = m_matrix->numCols ();
1518+ if (Q.numRows () != m || Q.numCols () != n)
1519+ throw std::invalid_argument (" [QR] invalid shape of Q." );
1520+ if (R.numRows () != n || R.numCols () != n)
1521+ throw std::invalid_argument (" [QR] invalid shape of R." );
1522+ // Initialize Q to 1's on diagonal
1523+ std::vector<float > vecQ (m * n, 0 .);
1524+ for (size_t i = 0 ; i < m; i++) {
1525+ vecQ[i * n + i] = 1 .;
1526+ }
1527+ Q.upload (vecQ, rowMajor);
1528+ // Apply Householder reflectors to compute Q
1529+ gpuErrChk (cusolverDnSormqr (Session::getInstance ().cuSolverHandle (),
1530+ CUBLAS_SIDE_LEFT, CUBLAS_OP_N, m, n, n,
1531+ m_matrix->raw (), m,
1532+ m_householder->raw (),
1533+ Q.raw (), m,
1534+ m_workspace->raw (), m_workspaceSize,
1535+ m_info->raw ()));
1536+ // Extract upper triangular R
1537+ std::vector<float > vecR (n * n, 0 .);
1538+ for (size_t r = 0 ; r < n; r++) {
1539+ for (size_t c = r; c < n; c++) {
1540+ vecR[r * n + c] = (*m_matrix)(r, c);
1541+ }
1542+ }
1543+ R.upload (vecR, rowMajor);
1544+ return (*m_info)(0 );
1545+ }
1546+
1547+
13311548/* ================================================================================================
13321549 * Nullspace (N)
13331550 * ================================================================================================ */
0 commit comments