@@ -1326,7 +1326,7 @@ public:
13261326};
13271327
13281328template <>
1329- void CholeskyBatchFactoriser<double >::factorise() {
1329+ inline void CholeskyBatchFactoriser<double >::factorise() {
13301330 if (m_factorisationDone) return ;
13311331 DTensor<double *> ptrA = m_matrix->pointersToMatrices ();
13321332 gpuErrChk (cusolverDnDpotrfBatched (Session::getInstance ().cuSolverHandle (),
@@ -1340,7 +1340,7 @@ void CholeskyBatchFactoriser<double>::factorise() {
13401340}
13411341
13421342template <>
1343- void CholeskyBatchFactoriser<float >::factorise() {
1343+ inline void CholeskyBatchFactoriser<float >::factorise() {
13441344 if (m_factorisationDone) return ;
13451345 DTensor<float *> ptrA = m_matrix->pointersToMatrices ();
13461346 gpuErrChk (cusolverDnSpotrfBatched (Session::getInstance ().cuSolverHandle (),
@@ -1354,8 +1354,11 @@ void CholeskyBatchFactoriser<float>::factorise() {
13541354}
13551355
13561356template <>
1357- void CholeskyBatchFactoriser<double >::solve(DTensor<double > &b) {
1357+ inline void CholeskyBatchFactoriser<double >::solve(DTensor<double > &b) {
13581358 if (!m_factorisationDone) throw std::logic_error (" [CholeskyBatchSolve] no factor to solve with" );
1359+ if (m_numRows != b.numRows () || m_numMats != b.numMats ()) {
1360+ throw std::invalid_argument (" [CholeskyBatchSolve] A and b incompatible" );
1361+
13591362 if (b.numCols () != 1 ) throw std::invalid_argument (" [CholeskyBatchSolve] only supports `b` with one column" );
13601363 DTensor<double *> ptrA = m_matrix->pointersToMatrices ();
13611364 DTensor<double *> ptrB = b.pointersToMatrices ();
@@ -1372,8 +1375,11 @@ void CholeskyBatchFactoriser<double>::solve(DTensor<double> &b) {
13721375}
13731376
13741377template <>
1375- void CholeskyBatchFactoriser<float >::solve(DTensor<float > &b) {
1378+ inline void CholeskyBatchFactoriser<float >::solve (DTensor<float > &b) {
13761379 if (!m_factorisationDone) throw std::logic_error (" [CholeskyBatchSolve] no factor to solve with" );
1380+ if (m_numRows != b.numRows () || m_numMats != b.numMats ()) {
1381+ throw std::invalid_argument (" [CholeskyBatchSolve] A and b incompatible" );
1382+ }
13771383 if (b.numCols () != 1 ) throw std::invalid_argument (" [CholeskyBatchSolve] only supports `b` with one column" );
13781384 DTensor<float *> ptrA = m_matrix->pointersToMatrices ();
13791385 DTensor<float *> ptrB = b.pointersToMatrices ();
0 commit comments