Skip to content

Commit 5c1ed5d

Browse files
author
Ruairi Moran
committed
fix cholesky
1 parent 6c498cc commit 5c1ed5d

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

include/tensor.cuh

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,7 +1326,7 @@ public:
13261326
};
13271327

13281328
template<>
1329-
void CholeskyBatchFactoriser<double>::factorise() {
1329+
inline void CholeskyBatchFactoriser<double>::factorise() {
13301330
if (m_factorisationDone) return;
13311331
DTensor<double *> ptrA = m_matrix->pointersToMatrices();
13321332
gpuErrChk(cusolverDnDpotrfBatched(Session::getInstance().cuSolverHandle(),
@@ -1340,7 +1340,7 @@ void CholeskyBatchFactoriser<double>::factorise() {
13401340
}
13411341

13421342
template<>
1343-
void CholeskyBatchFactoriser<float>::factorise() {
1343+
inline void CholeskyBatchFactoriser<float>::factorise() {
13441344
if (m_factorisationDone) return;
13451345
DTensor<float *> ptrA = m_matrix->pointersToMatrices();
13461346
gpuErrChk(cusolverDnSpotrfBatched(Session::getInstance().cuSolverHandle(),
@@ -1354,8 +1354,11 @@ void CholeskyBatchFactoriser<float>::factorise() {
13541354
}
13551355

13561356
template<>
1357-
void CholeskyBatchFactoriser<double>::solve(DTensor<double> &b) {
1357+
inline void CholeskyBatchFactoriser<double>::solve(DTensor<double> &b) {
13581358
if (!m_factorisationDone) throw std::logic_error("[CholeskyBatchSolve] no factor to solve with");
1359+
if (m_numRows != b.numRows() || m_numMats != b.numMats()) {
1360+
throw std::invalid_argument("[CholeskyBatchSolve] A and b incompatible");
1361+
13591362
if (b.numCols() != 1) throw std::invalid_argument("[CholeskyBatchSolve] only supports `b` with one column");
13601363
DTensor<double *> ptrA = m_matrix->pointersToMatrices();
13611364
DTensor<double *> ptrB = b.pointersToMatrices();
@@ -1372,8 +1375,11 @@ void CholeskyBatchFactoriser<double>::solve(DTensor<double> &b) {
13721375
}
13731376

13741377
template<>
1375-
void CholeskyBatchFactoriser<float>::solve(DTensor<float> &b) {
1378+
inline void CholeskyBatchFactoriser<float>::solve(DTensor<float> &b) {
13761379
if (!m_factorisationDone) throw std::logic_error("[CholeskyBatchSolve] no factor to solve with");
1380+
if (m_numRows != b.numRows() || m_numMats != b.numMats()) {
1381+
throw std::invalid_argument("[CholeskyBatchSolve] A and b incompatible");
1382+
}
13771383
if (b.numCols() != 1) throw std::invalid_argument("[CholeskyBatchSolve] only supports `b` with one column");
13781384
DTensor<float *> ptrA = m_matrix->pointersToMatrices();
13791385
DTensor<float *> ptrB = b.pointersToMatrices();

0 commit comments

Comments
 (0)