You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
Now we can remove `_th_potri`!
Compared to the original TH-based `cholesky_inverse`, complex (pytorch#33152) and batched inputs (pytorch#7500) are now supported both on CPU and CUDA.
Closespytorch#24685.
Closespytorch#24543.
Ref. pytorch#49421, pytorch#42666
Pull Request resolved: pytorch#50269
Reviewed By: bdhirsh
Differential Revision: D26047548
Pulled By: anjali411
fbshipit-source-id: e4f191e39c684f241b7cb0f4b4c025de082cccef
Copy file name to clipboardExpand all lines: aten/src/ATen/native/BatchLinearAlgebra.cpp
+97-3Lines changed: 97 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -49,6 +49,12 @@ extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, in
49
49
extern"C"voiddpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
50
50
extern"C"voidspotrf_(char *uplo, int *n, float *a, int *lda, int *info);
51
51
52
+
// potri
53
+
extern"C"voidzpotri_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
54
+
extern"C"voidcpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
55
+
extern"C"voiddpotri_(char *uplo, int *n, double *a, int *lda, int *info);
56
+
extern"C"voidspotri_(char *uplo, int *n, float *a, int *lda, int *info);
57
+
52
58
// trtrs
53
59
extern"C"voidztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
54
60
extern"C"voidctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
@@ -237,6 +243,22 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *
237
243
spotrf_(&uplo, &n, a, &lda, info);
238
244
}
239
245
246
+
template<> void lapackCholeskyInverse<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
template<> void lapackCholeskyInverse<double>(char uplo, int n, double *a, int lda, int *info) {
255
+
dpotri_(&uplo, &n, a, &lda, info);
256
+
}
257
+
258
+
template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda, int *info) {
259
+
spotri_(&uplo, &n, a, &lda, info);
260
+
}
261
+
240
262
template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
Copy file name to clipboardExpand all lines: aten/src/ATen/native/BatchLinearAlgebra.h
+7Lines changed: 7 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -14,6 +14,9 @@ namespace at { namespace native {
14
14
// Define per-batch functions to be used in the implementation of batched
15
15
// linear algebra operations
16
16
17
+
template<classscalar_t>
18
+
voidlapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
19
+
17
20
template<classscalar_t, classvalue_t=scalar_t>
18
21
voidlapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
19
22
@@ -22,6 +25,10 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala
22
25
23
26
#endif
24
27
28
+
using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool/*upper*/);
0 commit comments