Skip to content

Commit caf96e4

Browse files
committed
Refactor the generator
1 parent a27f64e commit caf96e4

File tree

2 files changed

+52
-55
lines changed

2 files changed

+52
-55
lines changed

bin/function.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
return_re = re.compile('(?:\s*->\s*([^;]+))?')
66

77

8-
class Function(object):
8+
class Function():
99

1010
def __init__(self, name, args, ret):
1111
self.name = name
@@ -25,8 +25,6 @@ def parse(line):
2525
arg, aty, line = pull_argument(line)
2626
if arg is None:
2727
break
28-
if arg == 'matrix_layout':
29-
arg = 'layout'
3028
args.append((arg, aty))
3129
line = line.strip()
3230

@@ -55,7 +53,7 @@ def pull_return(s):
5553
return match.group(1), s[match.end(1):]
5654

5755

58-
def read_functions(path):
56+
def read(path):
5957
lines = []
6058
with open(path) as file:
6159
append = False

bin/generate.py

+50-51
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import re
66

77
from function import Function
8-
from function import read_functions
8+
from function import read
99

1010
select_re = re.compile('LAPACK_(\w)_SELECT(\d)')
1111

1212

1313
def is_scalar(name, cty, f):
14-
return (
14+
return ( \
1515
'c_char' in cty or
1616
name in [
1717
'abnrm',
@@ -73,7 +73,33 @@ def is_scalar(name, cty, f):
7373
)
7474

7575

76-
def translate_argument(name, cty, f):
76+
def translate_name(name):
77+
mapping = {
78+
'matrix_layout': 'layout',
79+
}
80+
return mapping.get(name, name)
81+
82+
83+
def translate_base_type(cty):
84+
if 'c_char' in cty:
85+
return 'u8'
86+
elif 'c_float' in cty:
87+
return 'f32'
88+
elif 'c_double' in cty:
89+
return 'f64'
90+
elif 'lapack_int' in cty:
91+
return 'i32'
92+
elif 'lapack_logical' in cty:
93+
return 'i32'
94+
elif 'lapack_complex_float' in cty:
95+
return 'c32'
96+
elif 'lapack_complex_double' in cty:
97+
return 'c64'
98+
99+
assert False, 'cannot translate `{}`'.format(cty)
100+
101+
102+
def translate_signature_type(name, cty, f):
77103
if name == 'layout':
78104
return 'Layout'
79105

@@ -88,7 +114,7 @@ def translate_argument(name, cty, f):
88114
elif m.group(1) == 'Z':
89115
return 'Select{}C64'.format(m.group(2))
90116

91-
base = translate_type_base(cty)
117+
base = translate_base_type(cty)
92118
if '*const' in cty:
93119
if is_scalar(name, cty, f):
94120
return base
@@ -103,25 +129,6 @@ def translate_argument(name, cty, f):
103129
return base
104130

105131

106-
def translate_type_base(cty):
107-
if 'c_char' in cty:
108-
return 'u8'
109-
elif 'c_float' in cty:
110-
return 'f32'
111-
elif 'c_double' in cty:
112-
return 'f64'
113-
elif 'lapack_int' in cty:
114-
return 'i32'
115-
elif 'lapack_logical' in cty:
116-
return 'i32'
117-
elif 'lapack_complex_float' in cty:
118-
return 'c32'
119-
elif 'lapack_complex_double' in cty:
120-
return 'c64'
121-
122-
assert False, 'cannot translate `{}`'.format(cty)
123-
124-
125132
def translate_body_argument(name, rty):
126133
if rty == 'Layout':
127134
return '{}.into()'.format(name)
@@ -164,56 +171,48 @@ def translate_body_argument(name, rty):
164171
assert False, 'cannot translate `{}: {}`'.format(name, rty)
165172

166173

167-
def translate_return_type(cty):
168-
if cty == 'lapack_int':
169-
return 'i32'
170-
elif cty == 'c_float':
171-
return 'f32'
172-
elif cty == 'c_double':
173-
return 'f64'
174-
175-
assert False, 'cannot translate `{}`'.format(cty)
176-
177-
178-
def format_header(f):
179-
args = format_header_arguments(f)
174+
def format_signature(f):
175+
args = format_signature_arguments(f)
180176
if f.ret is None:
181177
return 'pub unsafe fn {}({})'.format(f.name, args)
182178
else:
183-
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args, translate_return_type(f.ret))
179+
return 'pub unsafe fn {}({}) -> {}'.format(f.name, args,
180+
translate_base_type(f.ret))
184181

185182

186-
def format_body(f):
187-
return 'ffi::LAPACKE_{}({})'.format(f.name, format_body_arguments(f))
188-
189-
190-
def format_header_arguments(f):
183+
def format_signature_arguments(f):
191184
s = []
192-
for arg in f.args:
193-
s.append('{}: {}'.format(arg[0], translate_argument(*arg, f=f)))
185+
for name, cty in f.args:
186+
name = translate_name(name)
187+
s.append('{}: {}'.format(name, translate_signature_type(name, cty, f)))
194188
return ', '.join(s)
195189

196190

191+
def format_body(f):
192+
return 'ffi::LAPACKE_{}({})'.format(f.name, format_body_arguments(f))
193+
194+
197195
def format_body_arguments(f):
198196
s = []
199-
for arg in f.args:
200-
rty = translate_argument(*arg, f=f)
201-
s.append(translate_body_argument(arg[0], rty))
197+
for name, cty in f.args:
198+
name = translate_name(name)
199+
rty = translate_signature_type(name, cty, f)
200+
s.append(translate_body_argument(name, rty))
202201
return ', '.join(s)
203202

204203

205-
def prepare(code):
204+
def process(code):
206205
lines = filter(lambda line: not re.match(r'^\s*//.*', line),
207206
code.split('\n'))
208207
lines = re.sub(r'\s+', ' ', ''.join(lines)).strip().split(';')
209208
lines = filter(lambda line: not re.match(r'^\s*$', line), lines)
210209
return [Function.parse(line) for line in lines]
211210

212211

213-
def do(functions):
212+
def write(functions):
214213
for f in functions:
215214
print('\n#[inline]')
216-
print(format_header(f) + ' {')
215+
print(format_signature(f) + ' {')
217216
print(' ' + format_body(f) + '\n}')
218217

219218

@@ -222,4 +221,4 @@ def do(functions):
222221
parser.add_argument('--sys', default='lapacke-sys')
223222
arguments = parser.parse_args()
224223
path = os.path.join(arguments.sys, 'src', 'lib.rs')
225-
do(prepare(read_functions(path)))
224+
write(process(read(path)))

0 commit comments

Comments
 (0)