5
5
import re
6
6
7
7
from function import Function
8
- from function import read_functions
8
+ from function import read
9
9
10
10
select_re = re .compile ('LAPACK_(\w)_SELECT(\d)' )
11
11
12
12
13
13
def is_scalar (name , cty , f ):
14
- return (
14
+ return ( \
15
15
'c_char' in cty or
16
16
name in [
17
17
'abnrm' ,
@@ -73,7 +73,33 @@ def is_scalar(name, cty, f):
73
73
)
74
74
75
75
76
- def translate_argument (name , cty , f ):
76
+ def translate_name (name ):
77
+ mapping = {
78
+ 'matrix_layout' : 'layout' ,
79
+ }
80
+ return mapping .get (name , name )
81
+
82
+
83
+ def translate_base_type (cty ):
84
+ if 'c_char' in cty :
85
+ return 'u8'
86
+ elif 'c_float' in cty :
87
+ return 'f32'
88
+ elif 'c_double' in cty :
89
+ return 'f64'
90
+ elif 'lapack_int' in cty :
91
+ return 'i32'
92
+ elif 'lapack_logical' in cty :
93
+ return 'i32'
94
+ elif 'lapack_complex_float' in cty :
95
+ return 'c32'
96
+ elif 'lapack_complex_double' in cty :
97
+ return 'c64'
98
+
99
+ assert False , 'cannot translate `{}`' .format (cty )
100
+
101
+
102
+ def translate_signature_type (name , cty , f ):
77
103
if name == 'layout' :
78
104
return 'Layout'
79
105
@@ -88,7 +114,7 @@ def translate_argument(name, cty, f):
88
114
elif m .group (1 ) == 'Z' :
89
115
return 'Select{}C64' .format (m .group (2 ))
90
116
91
- base = translate_type_base (cty )
117
+ base = translate_base_type (cty )
92
118
if '*const' in cty :
93
119
if is_scalar (name , cty , f ):
94
120
return base
@@ -103,25 +129,6 @@ def translate_argument(name, cty, f):
103
129
return base
104
130
105
131
106
- def translate_type_base (cty ):
107
- if 'c_char' in cty :
108
- return 'u8'
109
- elif 'c_float' in cty :
110
- return 'f32'
111
- elif 'c_double' in cty :
112
- return 'f64'
113
- elif 'lapack_int' in cty :
114
- return 'i32'
115
- elif 'lapack_logical' in cty :
116
- return 'i32'
117
- elif 'lapack_complex_float' in cty :
118
- return 'c32'
119
- elif 'lapack_complex_double' in cty :
120
- return 'c64'
121
-
122
- assert False , 'cannot translate `{}`' .format (cty )
123
-
124
-
125
132
def translate_body_argument (name , rty ):
126
133
if rty == 'Layout' :
127
134
return '{}.into()' .format (name )
@@ -164,56 +171,48 @@ def translate_body_argument(name, rty):
164
171
assert False , 'cannot translate `{}: {}`' .format (name , rty )
165
172
166
173
167
- def translate_return_type (cty ):
168
- if cty == 'lapack_int' :
169
- return 'i32'
170
- elif cty == 'c_float' :
171
- return 'f32'
172
- elif cty == 'c_double' :
173
- return 'f64'
174
-
175
- assert False , 'cannot translate `{}`' .format (cty )
176
-
177
-
178
- def format_header (f ):
179
- args = format_header_arguments (f )
174
+ def format_signature (f ):
175
+ args = format_signature_arguments (f )
180
176
if f .ret is None :
181
177
return 'pub unsafe fn {}({})' .format (f .name , args )
182
178
else :
183
- return 'pub unsafe fn {}({}) -> {}' .format (f .name , args , translate_return_type (f .ret ))
179
+ return 'pub unsafe fn {}({}) -> {}' .format (f .name , args ,
180
+ translate_base_type (f .ret ))
184
181
185
182
186
- def format_body (f ):
187
- return 'ffi::LAPACKE_{}({})' .format (f .name , format_body_arguments (f ))
188
-
189
-
190
- def format_header_arguments (f ):
183
+ def format_signature_arguments (f ):
191
184
s = []
192
- for arg in f .args :
193
- s .append ('{}: {}' .format (arg [0 ], translate_argument (* arg , f = f )))
185
+ for name , cty in f .args :
186
+ name = translate_name (name )
187
+ s .append ('{}: {}' .format (name , translate_signature_type (name , cty , f )))
194
188
return ', ' .join (s )
195
189
196
190
191
+ def format_body (f ):
192
+ return 'ffi::LAPACKE_{}({})' .format (f .name , format_body_arguments (f ))
193
+
194
+
197
195
def format_body_arguments (f ):
198
196
s = []
199
- for arg in f .args :
200
- rty = translate_argument (* arg , f = f )
201
- s .append (translate_body_argument (arg [0 ], rty ))
197
+ for name , cty in f .args :
198
+ name = translate_name (name )
199
+ rty = translate_signature_type (name , cty , f )
200
+ s .append (translate_body_argument (name , rty ))
202
201
return ', ' .join (s )
203
202
204
203
205
- def prepare (code ):
204
+ def process (code ):
206
205
lines = filter (lambda line : not re .match (r'^\s*//.*' , line ),
207
206
code .split ('\n ' ))
208
207
lines = re .sub (r'\s+' , ' ' , '' .join (lines )).strip ().split (';' )
209
208
lines = filter (lambda line : not re .match (r'^\s*$' , line ), lines )
210
209
return [Function .parse (line ) for line in lines ]
211
210
212
211
213
- def do (functions ):
212
+ def write (functions ):
214
213
for f in functions :
215
214
print ('\n #[inline]' )
216
- print (format_header (f ) + ' {' )
215
+ print (format_signature (f ) + ' {' )
217
216
print (' ' + format_body (f ) + '\n }' )
218
217
219
218
@@ -222,4 +221,4 @@ def do(functions):
222
221
parser .add_argument ('--sys' , default = 'lapacke-sys' )
223
222
arguments = parser .parse_args ()
224
223
path = os .path .join (arguments .sys , 'src' , 'lib.rs' )
225
- do ( prepare ( read_functions (path )))
224
+ write ( process ( read (path )))
0 commit comments