
Commit 6e4746c

IvanYashchuk authored and facebook-github-bot committed
Port cholesky_inverse to ATen (pytorch#50269)
Summary: Now we can remove `_th_potri`! Compared to the original TH-based `cholesky_inverse`, complex (pytorch#33152) and batched inputs (pytorch#7500) are now supported both on CPU and CUDA. Closes pytorch#24685. Closes pytorch#24543. Ref. pytorch#49421, pytorch#42666

Pull Request resolved: pytorch#50269
Reviewed By: bdhirsh
Differential Revision: D26047548
Pulled By: anjali411
fbshipit-source-id: e4f191e39c684f241b7cb0f4b4c025de082cccef
1 parent 9f6e0de commit 6e4746c

17 files changed, +413 -294 lines changed
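
For orientation, not part of the commit itself: a minimal usage sketch of the batched and complex input support this port enables, assuming a PyTorch build containing this change (the tensor shapes, dtype, and the torch.cholesky/torch.allclose calls are illustrative assumptions, not taken from the diff).

import torch

# Illustrative sketch: a batch of Hermitian positive-definite matrices (complex128)
a = torch.randn(3, 4, 4, dtype=torch.complex128)
a = a @ a.transpose(-2, -1).conj() + 4 * torch.eye(4, dtype=torch.complex128)
l = torch.cholesky(a)                      # lower-triangular factors L, with A = L L^H
a_inv = torch.cholesky_inverse(l)          # inverse reconstructed from the factors, per batch element
print(torch.allclose(a_inv, torch.inverse(a)))  # expected: True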

aten/src/ATen/LegacyTHFunctionsCPU.cpp

Lines changed: 0 additions & 43 deletions
@@ -686,49 +686,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
     }
     return std::tuple<Tensor, Tensor>(res1, res2);
 }
-Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_potri(output_, self_, upper);
-            break;
-        }
-        case ScalarType::Float: {
-            auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_potri(output_, self_, upper);
-            break;
-        }
-        default:
-            AT_ERROR("_th_potri_out not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return output;
-}
-Tensor _th_potri(const Tensor & self, bool upper) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto output_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto output = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(output_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_potri(output_, self_, upper);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_potri(output_, self_, upper);
-            break;
-        }
-        default:
-            AT_ERROR("_th_potri not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return output;
-}
 std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/LegacyTHFunctionsCPU.h

Lines changed: 0 additions & 2 deletions
@@ -38,8 +38,6 @@ Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scala
 Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max);
 std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
 std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
-Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
-Tensor _th_potri(const Tensor & self, bool upper);
 std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
 std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
 Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);

aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp

Lines changed: 0 additions & 43 deletions
@@ -1062,49 +1062,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
     }
     return std::tuple<Tensor, Tensor>(res1, res2);
 }
-Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_potri(globalContext().getTHCState(), output_, self_, upper);
-            break;
-        }
-        case ScalarType::Float: {
-            auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_potri(globalContext().getTHCState(), output_, self_, upper);
-            break;
-        }
-        default:
-            AT_ERROR("_th_potri_out not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return output;
-}
-Tensor _th_potri(const Tensor & self, bool upper) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto output_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
-    auto output = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(output_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaDoubleTensor_potri(globalContext().getTHCState(), output_, self_, upper);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CUDA, dispatch_scalar_type);
-            THCudaTensor_potri(globalContext().getTHCState(), output_, self_, upper);
-            break;
-        }
-        default:
-            AT_ERROR("_th_potri not supported on CUDAType for ", dispatch_scalar_type);
-    }
-    return output;
-}
 std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 97 additions & 3 deletions
@@ -49,6 +49,12 @@ extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, in
 extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
 extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);
 
+// potri
+extern "C" void zpotri_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
+extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
+extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
+extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
+
 // trtrs
 extern "C" void ztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
 extern "C" void ctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
@@ -237,6 +243,22 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *
   spotrf_(&uplo, &n, a, &lda, info);
 }
 
+template<> void lapackCholeskyInverse<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
+  zpotri_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
+}
+
+template<> void lapackCholeskyInverse<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
+  cpotri_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
+}
+
+template<> void lapackCholeskyInverse<double>(char uplo, int n, double *a, int lda, int *info) {
+  dpotri_(&uplo, &n, a, &lda, info);
+}
+
+template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda, int *info) {
+  spotri_(&uplo, &n, a, &lda, info);
+}
+
 template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
   ztrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
 }
@@ -411,7 +433,7 @@ Computes the solution to a system of linear equations
 where A is an n-by-n matrix and X and B are n-by-nrhs matrices.
 Note that B is required to be a matrix, the usual, vector case, is obtained with nrhs = 1.
 Above description is for non-batched input, the batched input is also supported.
-This is an in-place routine, content of both A and b are overriden.
+This is an in-place routine, content of both A and b are overwritten.
 'infos' is an int Tensor containing error codes for each matrix in the batched input.
 For more information see LAPACK's documentation for GESV routine.
 */
@@ -480,7 +502,7 @@ std::tuple<Tensor&,Tensor&> solve_out(Tensor& solution, Tensor& lu, const Tensor
 // This is a type dispatching helper function for 'apply_solve'
 Tensor& _linalg_solve_out_helper_cpu(Tensor& result, Tensor& input, Tensor& infos) {
   // 'result' and 'input' should be in column major order (it should be checked before calling this function)
-  // the content of 'result', 'input' and 'infos' is overriden by 'apply_solve'
+  // the content of 'result', 'input' and 'infos' is overwritten by 'apply_solve'
   // 'result' should contain data of 'other' tensor (right-hand-side of the linear system of equations)
   // 'input' should contain data of original 'input' tensor (left-hand-side of the linear system of equations)
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_solve_out_cpu", [&]{
@@ -861,6 +883,78 @@ Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+DEFINE_DISPATCH(cholesky_inverse_stub);
+
+Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper) {
+  TORCH_INTERNAL_ASSERT(input.dim() >= 2);
+  TORCH_INTERNAL_ASSERT(input.size(-1) == input.size(-2));
+
+  TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
+  TORCH_INTERNAL_ASSERT(result.device() == input.device());
+
+  TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
+  TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU);
+  TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
+
+  // if result has no elements we can modify it
+  if (result.numel() == 0) {
+    at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous);
+    result.transpose_(-2, -1);
+  }
+
+  // result tensor must be in batched column major order (Fortran contiguous)
+  TORCH_INTERNAL_ASSERT(result.transpose(-2, -1).is_contiguous());
+  TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));
+
+  // cholesky_inverse_stub (apply_cholesky_inverse) performs calculations in-place and result must be a copy of input
+  result.copy_(input);
+
+  // infos must be contiguous
+  TORCH_INTERNAL_ASSERT(infos.is_contiguous());
+  infos.fill_(0);
+
+  result = cholesky_inverse_stub(result.device().type(), result, infos, upper);
+  return result;
+}
+
+Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
+  squareCheckInputs(input);
+  TORCH_CHECK(result.scalar_type() == input.scalar_type(),
+    "result dtype ", result.scalar_type(), " does not match input dtype ", input.scalar_type());
+  TORCH_CHECK(result.device() == input.device(),
+    "result device ", result.device(), " does not match input device ", input.device());
+
+  // MAGMA requires 'infos' to reside in CPU memory, therefore we create 'infos' only on CPU for now.
+  auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU));
+
+  // if result is not empty and not in batched column major format we have to allocate a temporary tensor
+  if (result.numel() != 0 && !result.transpose(-2, -1).is_contiguous()) {
+    Tensor result_tmp = at::empty({0}, input.options());
+    result_tmp = cholesky_inverse_out_info(result_tmp, infos, input, upper);
+    at::native::resize_output(result, result_tmp.sizes());
+    result.copy_(result_tmp);
+  } else {
+    // use result's memory directly
+    result = cholesky_inverse_out_info(result, infos, input, upper);
+  }
+
+  // Now check LAPACK/MAGMA error codes
+  if (result.dim() > 2) {
+    batchCheckErrors(infos, "cholesky_inverse");
+  } else {
+    singleCheckErrors(infos.item().toInt(), "cholesky_inverse");
+  }
+  return result;
+}
+
+Tensor cholesky_inverse(const Tensor &input, bool upper) {
+  Tensor result = at::empty({0}, input.options());
+  result = at::cholesky_inverse_out(result, input, upper);
+  return result;
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template<typename scalar_t>
@@ -1230,7 +1324,7 @@ Tensor orgqr(const Tensor& input, const Tensor& tau) {
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 // This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v'
-// The computation is done in-place: 'v' stores the input and will be overriden, 'w' should be an allocated empty array
+// The computation is done in-place: 'v' stores the input and will be overwritten, 'w' should be an allocated empty array
 // compute_v controls whether eigenvectors should be computed
 // uplo_str controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
 // infos is used to store information for possible checks for error

aten/src/ATen/native/BatchLinearAlgebra.h

Lines changed: 7 additions & 0 deletions
@@ -14,6 +14,9 @@ namespace at { namespace native {
 // Define per-batch functions to be used in the implementation of batched
 // linear algebra operations
 
+template<class scalar_t>
+void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);
+
 template<class scalar_t, class value_t=scalar_t>
 void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);
 
@@ -22,6 +25,10 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala
 
 #endif
 
+using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
+
+DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
+
 using eig_fn = std::tuple<Tensor, Tensor> (*)(const Tensor&, bool&);
 
 DECLARE_DISPATCH(eig_fn, eig_stub);

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 77 additions & 0 deletions
@@ -10,6 +10,79 @@ namespace at { namespace native {
 
 namespace {
 
+/*
+Copies the lower (or upper) triangle of the square matrix to the other half and conjugates it.
+This operation is performed in-place.
+*/
+template <typename scalar_t>
+void apply_reflect_conj_tri_single(scalar_t* self, int64_t n, int64_t stride, bool upper) {
+  std::function<void(int64_t, int64_t)> loop = [](int64_t, int64_t){};
+  if (upper) {
+    loop = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        for (int64_t j = i + 1; j < n; j++) {
+          self[i * stride + j] = conj_impl(self[j * stride + i]);
+        }
+      }
+    };
+  } else {
+    loop = [&](int64_t start, int64_t end) {
+      for (int64_t i = start; i < end; i++) {
+        for (int64_t j = 0; j < i; j++) {
+          self[i * stride + j] = conj_impl(self[j * stride + i]);
+        }
+      }
+    };
+  }
+  // For small matrices OpenMP overhead is too large
+  if (n < 256) {
+    loop(0, n);
+  } else {
+    at::parallel_for(0, n, 0, loop);
+  }
+}
+
+/*
+Computes the inverse of a symmetric (Hermitian) positive-definite matrix n-by-n matrix 'input' using the Cholesky factorization
+This is an in-place routine, content of 'input' is overwritten.
+'infos' is an int Tensor containing error codes for each matrix in the batched input.
+For more information see LAPACK's documentation for POTRI routine.
+*/
+template <typename scalar_t>
+void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) {
+#ifndef USE_LAPACK
+  TORCH_CHECK(false, "cholesky_inverse: LAPACK library not found in compilation");
+#else
+  char uplo = upper ? 'U' : 'L';
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto infos_data = infos.data_ptr<int>();
+  auto input_matrix_stride = matrixStride(input);
+  auto batch_size = batchCount(input);
+  auto n = input.size(-2);
+  auto lda = std::max<int64_t>(1, n);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+    int* info_working_ptr = &infos_data[i];
+    lapackCholeskyInverse<scalar_t>(uplo, n, input_working_ptr, lda, info_working_ptr);
+    // LAPACK writes to only upper/lower part of the matrix leaving the other side unchanged
+    apply_reflect_conj_tri_single<scalar_t>(input_working_ptr, n, lda, upper);
+  }
+#endif
+}
+
+// This is a type dispatching helper function for 'apply_cholesky_inverse'
+Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) {
+  // This function calculates the inverse matrix in-place
+  // result should be in column major order and contain matrices to invert
+  // the content of result is overwritten by 'apply_cholesky_inverse'
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "cholesky_inverse_out_cpu", [&]{
+    apply_cholesky_inverse<scalar_t>(result, infos, upper);
+  });
+  return result;
+}
+
 template <typename scalar_t>
 void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int64_t* info_ptr) {
 #ifndef USE_LAPACK
@@ -98,6 +171,10 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau, Tensor& infos, int6
 
 } // anonymous namespace
 
+REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
+REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
+REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
+
 REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl);
 REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl);
 REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl);
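
For context, not part of the diff: the CPU kernel above calls LAPACK's ?potri, which writes only the requested triangle of the inverse, and apply_reflect_conj_tri_single then mirrors (and conjugates) it into the other half. A rough Python sketch of those two steps, assuming SciPy's LAPACK wrappers dpotrf/dpotri with a `lower` keyword (an assumption about SciPy, not something stated in the PR):

import numpy as np
from scipy.linalg.lapack import dpotrf, dpotri

b = np.random.randn(4, 4)
a = b @ b.T + 4 * np.eye(4)                          # symmetric positive-definite input
c, info = dpotrf(a, lower=1)                         # Cholesky factor L in the lower triangle
inv_tri, info = dpotri(c, lower=1)                   # A^{-1}, but only the lower triangle is written
a_inv = np.tril(inv_tri) + np.tril(inv_tri, -1).T    # reflect to the upper half, like the C++ helper
print(np.allclose(a_inv, np.linalg.inv(a)))          # expected: True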
