1+ #include " BlasLapackInterface.hpp"
2+ #include " backendHelpers.hpp"
3+
4+ extern " C"
5+ {
6+
7+ // //////////////////////////////////////////////////////////////////////////////////////////////////
8+ #define LVARRAY_SGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( sgemm )
9+ void LVARRAY_SGEMM(
10+ char const * TRANSA,
11+ char const * TRANSB,
12+ int const * M,
13+ int const * N,
14+ int const * K,
15+ float const * ALPHA,
16+ float const * A,
17+ int const * LDA,
18+ float const * B,
19+ int const * LDB,
20+ float const * BETA,
21+ float * C,
22+ int const * LDC );
23+
24+ // //////////////////////////////////////////////////////////////////////////////////////////////////
25+ #define LVARRAY_DGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( dgemm )
26+ void LVARRAY_DGEMM(
27+ char const * TRANSA,
28+ char const * TRANSB,
29+ int const * M,
30+ int const * N,
31+ int const * K,
32+ double const * ALPHA,
33+ double const * A,
34+ int const * LDA,
35+ double const * B,
36+ int const * LDB,
37+ double const * BETA,
38+ double * C,
39+ int const * LDC );
40+
41+ // //////////////////////////////////////////////////////////////////////////////////////////////////
42+ #define LVARRAY_CGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( cgemm )
43+ void LVARRAY_CGEMM(
44+ char const * TRANSA,
45+ char const * TRANSB,
46+ int const * M,
47+ int const * N,
48+ int const * K,
49+ std::complex < float > const * ALPHA,
50+ std::complex < float > const * A,
51+ int const * LDA,
52+ std::complex < float > const * B,
53+ int const * LDB,
54+ std::complex < float > const * BETA,
55+ std::complex < float > * C,
56+ int const * LDC );
57+
58+ // //////////////////////////////////////////////////////////////////////////////////////////////////
59+ #define LVARRAY_ZGEMM LVARRAY_LAPACK_FORTRAN_MANGLE ( zgemm )
60+ void LVARRAY_ZGEMM(
61+ char const * TRANSA,
62+ char const * TRANSB,
63+ int const * M,
64+ int const * N,
65+ int const * K,
66+ std::complex < double > const * ALPHA,
67+ std::complex < double > const * A,
68+ int const * LDA,
69+ std::complex < double > const * B,
70+ int const * LDB,
71+ std::complex < double > const * BETA,
72+ std::complex < double > * C,
73+ int const * LDC );
74+
75+ // //////////////////////////////////////////////////////////////////////////////////////////////////
76+ #define LVARRAY_SGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( sgesv )
77+ void LVARRAY_SGESV(
78+ int const * N,
79+ int const * NRHS,
80+ float * A,
81+ int const * LDA,
82+ int * IPIV,
83+ float * B,
84+ int const * LDB,
85+ int * INFO );
86+
87+ // //////////////////////////////////////////////////////////////////////////////////////////////////
88+ #define LVARRAY_DGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( dgesv )
89+ void LVARRAY_DGESV(
90+ int const * N,
91+ int const * NRHS,
92+ double * A,
93+ int const * LDA,
94+ int * IPIV,
95+ double * B,
96+ int const * LDB,
97+ int * INFO );
98+
99+ // //////////////////////////////////////////////////////////////////////////////////////////////////
100+ #define LVARRAY_CGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( cgesv )
101+ void LVARRAY_CGESV(
102+ int const * N,
103+ int const * NRHS,
104+ std::complex < float > * A,
105+ int const * LDA,
106+ int * IPIV,
107+ std::complex < float > * B,
108+ int const * LDB,
109+ int * INFO );
110+
111+ // //////////////////////////////////////////////////////////////////////////////////////////////////
112+ #define LVARRAY_ZGESV LVARRAY_LAPACK_FORTRAN_MANGLE ( zgesv )
113+ void LVARRAY_ZGESV(
114+ int const * N,
115+ int const * NRHS,
116+ std::complex < double > * A,
117+ int const * LDA,
118+ int * IPIV,
119+ std::complex < double > * B,
120+ int const * LDB,
121+ int * INFO );
122+
123+ } // extern "C"
124+
125+ namespace LvArray
126+ {
127+ namespace dense
128+ {
129+
130+ char toLapackChar ( Operation const op )
131+ {
132+ if ( op == Operation::NO_OP ) return ' N' ;
133+ if ( op == Operation::TRANSPOSE ) return ' T' ;
134+ if ( op == Operation::ADJOINT ) return ' C' ;
135+
136+ LVARRAY_ERROR ( " Unknown operation: " << int ( op ) );
137+ return ' \0 ' ;
138+ }
139+
140+
141+ template < typename T >
142+ void BlasLapackInterface< T >::gemm(
143+ Operation opA,
144+ Operation opB,
145+ T const alpha,
146+ Matrix< T const > const & A,
147+ Matrix< T const > const & B,
148+ T const beta,
149+ Matrix< T > const & C )
150+ {
151+ char const TRANSA = toLapackChar ( opA );
152+ char const TRANSB = toLapackChar ( opB );
153+ int const M = C.sizes [ 0 ];
154+ int const N = C.sizes [ 1 ];
155+ int const K = opA == Operation::NO_OP ? A.sizes [ 1 ] : A.sizes [ 0 ];
156+ int const LDA = std::max ( std::ptrdiff_t { 1 }, A.strides [ 1 ] );
157+ int const LDB = std::max ( std::ptrdiff_t { 1 }, B.strides [ 1 ] );
158+ int const LDC = std::max ( std::ptrdiff_t { 1 }, C.strides [ 1 ] );
159+
160+ TypeDispatch< T >::dispatch ( LVARRAY_SGEMM, LVARRAY_DGEMM, LVARRAY_CGEMM, LVARRAY_ZGEMM,
161+ &TRANSA,
162+ &TRANSB,
163+ &M,
164+ &N,
165+ &K,
166+ &alpha,
167+ A.data ,
168+ &LDA,
169+ B.data ,
170+ &LDB,
171+ &beta,
172+ C.data ,
173+ &LDC );
174+ }
175+
176+
177+ template < typename T >
178+ void BlasLapackInterface< T >::gesv(
179+ Matrix< T > const & A,
180+ Matrix< T > const & B,
181+ Vector< int > const & pivots )
182+ {
183+ int const N = A.sizes [ 0 ];
184+ int const NRHS = B.sizes [ 1 ];
185+ int const LDA = A.strides [ 1 ];
186+ int const LDB = B.strides [ 1 ];
187+ int INFO = 0 ;
188+
189+ TypeDispatch< T >::dispatch ( LVARRAY_SGESV, LVARRAY_DGESV, LVARRAY_CGESV, LVARRAY_ZGESV,
190+ &N,
191+ &NRHS,
192+ A.data ,
193+ &LDA,
194+ pivots.data ,
195+ B.data ,
196+ &LDB,
197+ &INFO );
198+
199+ LVARRAY_ERROR_IF ( INFO < 0 , " The " << -INFO << " -th argument had an illegal value." );
200+ LVARRAY_ERROR_IF ( INFO > 0 , " The factorization has been completed but U( " << INFO - 1 << " , " << INFO - 1 <<
201+ " ) is exactly zero so the solution could not be computed." );
202+ }
203+
204+ template class BlasLapackInterface < float >;
205+ template class BlasLapackInterface < double >;
206+ template class BlasLapackInterface < std::complex < float > >;
207+ template class BlasLapackInterface < std::complex < double > >;
208+
209+ } // namespace dense
210+ } // namespace LvArray
0 commit comments