Skip to content

Commit d101a1f

Browse files
committed
Adding airthmetic operations
1 parent 9cbe920 commit d101a1f

File tree

5 files changed

+209
-13
lines changed

5 files changed

+209
-13
lines changed

arrayfire.lua

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ require('arrayfire.defines')
33
require('arrayfire.dim4')
44
require('arrayfire.util')
55
require('arrayfire.array')
6+
require('arrayfire.arith')
67
require('arrayfire.device')
78

89
return af

arrayfire/arith.lua

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
require('arrayfire.lib')
2+
require('arrayfire.defines')
3+
require('arrayfire.array')
4+
local ffi = require( "ffi" )
5+
6+
local funcs = {}
7+
8+
funcs[30] = [[
9+
af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
10+
af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
11+
af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
12+
af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
13+
af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
14+
af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
15+
af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
16+
af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
17+
af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
18+
af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
19+
af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
20+
af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
21+
af_err af_not (af_array *out, const af_array in);
22+
af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
23+
af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
24+
af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
25+
af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
26+
af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
27+
af_err af_cast (af_array *out, const af_array in, const af_dtype type);
28+
af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
29+
af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
30+
af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
31+
af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
32+
af_err af_abs (af_array *out, const af_array in);
33+
af_err af_arg (af_array *out, const af_array in);
34+
af_err af_sign (af_array *out, const af_array in);
35+
af_err af_round (af_array *out, const af_array in);
36+
af_err af_trunc (af_array *out, const af_array in);
37+
af_err af_floor (af_array *out, const af_array in);
38+
af_err af_ceil (af_array *out, const af_array in);
39+
af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
40+
af_err af_sin (af_array *out, const af_array in);
41+
af_err af_cos (af_array *out, const af_array in);
42+
af_err af_tan (af_array *out, const af_array in);
43+
af_err af_asin (af_array *out, const af_array in);
44+
af_err af_acos (af_array *out, const af_array in);
45+
af_err af_atan (af_array *out, const af_array in);
46+
af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
47+
af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
48+
af_err af_cplx (af_array *out, const af_array in);
49+
af_err af_real (af_array *out, const af_array in);
50+
af_err af_imag (af_array *out, const af_array in);
51+
af_err af_conjg (af_array *out, const af_array in);
52+
af_err af_sinh (af_array *out, const af_array in);
53+
af_err af_cosh (af_array *out, const af_array in);
54+
af_err af_tanh (af_array *out, const af_array in);
55+
af_err af_asinh (af_array *out, const af_array in);
56+
af_err af_acosh (af_array *out, const af_array in);
57+
af_err af_atanh (af_array *out, const af_array in);
58+
af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
59+
af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
60+
af_err af_pow2 (af_array *out, const af_array in);
61+
af_err af_exp (af_array *out, const af_array in);
62+
63+
af_err af_expm1 (af_array *out, const af_array in);
64+
af_err af_erf (af_array *out, const af_array in);
65+
af_err af_erfc (af_array *out, const af_array in);
66+
af_err af_log (af_array *out, const af_array in);
67+
af_err af_log1p (af_array *out, const af_array in);
68+
af_err af_log10 (af_array *out, const af_array in);
69+
af_err af_log2 (af_array *out, const af_array in);
70+
af_err af_sqrt (af_array *out, const af_array in);
71+
af_err af_cbrt (af_array *out, const af_array in);
72+
af_err af_factorial (af_array *out, const af_array in);
73+
af_err af_tgamma (af_array *out, const af_array in);
74+
af_err af_lgamma (af_array *out, const af_array in);
75+
af_err af_iszero (af_array *out, const af_array in);
76+
af_err af_isinf (af_array *out, const af_array in);
77+
af_err af_isnan (af_array *out, const af_array in);
78+
]]
79+
80+
funcs[31] = [[
81+
af_err af_sigmoid (af_array *out, const af_array in);
82+
]]
83+
84+
funcs[34] = [[
85+
af_err af_clamp(af_array *out, const af_array in,
86+
const af_array lo, const af_array hi, const bool batch);
87+
]]
88+
89+
af.lib.cdef(funcs)
90+
local c_array_p = af.ffi.c_array_p
91+
local init = af.Array.init
92+
93+
local binaryFuncs = {
94+
'add',
95+
'sub',
96+
'mul',
97+
'div',
98+
'lt',
99+
'gt',
100+
'le',
101+
'ge',
102+
'eq',
103+
'neq',
104+
'and',
105+
'or',
106+
'bitand',
107+
'bitor',
108+
'bitxor',
109+
'bitshiftl',
110+
'bitshiftr',
111+
'minof',
112+
'maxof',
113+
'rem',
114+
'mod',
115+
'hypot',
116+
'atan2',
117+
'cplx2',
118+
'root',
119+
'pow',
120+
}
121+
122+
123+
for _, func in ipairs(binaryFuncs) do
124+
af[func] = function(lhs, rhs, batch)
125+
-- TODO: add support for numbers
126+
-- TODO: add support for batch mode
127+
local res = c_array_p()
128+
af.clib['af_' .. func](res, lhs:get(), rhs:get(), batch and true or false)
129+
return init(res[0])
130+
end
131+
end
132+
133+
local unaryFuncs = {
134+
'abs',
135+
'arg',
136+
'sign',
137+
'round',
138+
'trunc',
139+
'floor',
140+
'ceil',
141+
'sin',
142+
'cos',
143+
'tan',
144+
'asin',
145+
'acos',
146+
'atan',
147+
'cplx',
148+
'real',
149+
'imag',
150+
'conjg',
151+
'sinh',
152+
'cosh',
153+
'tanh',
154+
'asinh',
155+
'acosh',
156+
'atanh',
157+
'pow2',
158+
'exp',
159+
'expm1',
160+
'erf',
161+
'erfc',
162+
'log',
163+
'log1p',
164+
'log10',
165+
'log2',
166+
'sqrt',
167+
'cbrt',
168+
'factorial',
169+
'tgamma',
170+
'lgamma',
171+
'iszero',
172+
'isinf',
173+
'isnan'
174+
}
175+
176+
for _, func in ipairs(unaryFuncs) do
177+
af[func] = function(input)
178+
-- TODO: add support for numbers
179+
-- TODO: add support for batch mode
180+
local res = c_array_p()
181+
af.clib['af_' .. func](res, input:get())
182+
return init(res[0])
183+
end
184+
end
185+
186+
af.cast = function(input, rtype)
187+
local res = c_array_p()
188+
af.clib.af_cast(res, input:get(), rtype)
189+
return init(res[0])
190+
end

arrayfire/array.lua

+16-12
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,19 @@ local c_uint_t = af.ffi.c_uint_t
7171
local c_ptr_t = af.ffi.c_ptr_t
7272
local Dim4 = af.Dim4
7373

74+
function release_array(ptr)
75+
local res = af.clib.af_release_array(ptr)
76+
-- TODO: Error handling logic
77+
end
78+
7479
local c_array_p = function(ptr)
7580
local arr_ptr = ffi.new('void *[1]', ptr)
76-
arr_ptr[0] = ffi.gc(arr_ptr[0], af.clib.af_release_array)
7781
return arr_ptr
7882
end
7983

8084
local init = function(ptr)
8185
local self = setmetatable({}, Array)
82-
self._array = ptr
86+
self._ptr = ffi.gc(ptr, release_array)
8387
return self
8488
end
8589

@@ -117,51 +121,51 @@ Array.__tostring = function(self)
117121
end
118122

119123
Array.get = function(self)
120-
return self._array
124+
return self._ptr
121125
end
122126

123127
-- TODO: implement Array.write
124128

125129
Array.copy = function(self)
126130
local res = c_array_p()
127-
af.clib.af_copy_array(res, self._array)
131+
af.clib.af_copy_array(res, self:get())
128132
return Array.init(res[0])
129133
end
130134

131135
Array.softCopy = function(self)
132136
local res = c_array_p()
133-
af.clib.af_copy_array(res, self._array)
137+
af.clib.af_copy_array(res, self:get())
134138
return Array.init(res[0])
135139
end
136140

137141
Array.elements = function(self)
138142
local res = c_ptr_t('dim_t')
139-
af.clib.af_get_elements(res, self._array)
143+
af.clib.af_get_elements(res, self:get())
140144
return tonumber(res[0])
141145
end
142146

143147
Array.type = function(self)
144148
local res = c_ptr_t('af_dtype')
145-
af.clib.af_get_type(res, self._array)
149+
af.clib.af_get_type(res, self:get())
146150
return tonumber(res[0])
147151
end
148152

149153
Array.typeName = function(self)
150154
local res = c_ptr_t('af_dtype')
151-
af.clib.af_get_type(res, self._array)
155+
af.clib.af_get_type(res, self:get())
152156
return af.dtype_names[tonumber(res[0])]
153157
end
154158

155159
Array.dims = function(self)
156160
local res = c_dim4_t()
157-
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self._array)
161+
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self:get())
158162
return Dim4(tonumber(res[0]), tonumber(res[1]),
159163
tonumber(res[2]), tonumber(res[3]))
160164
end
161165

162166
Array.numdims = function(self)
163167
local res = c_ptr_t('unsigned int')
164-
af.clib.af_get_numdims(res, self._array)
168+
af.clib.af_get_numdims(res, self:get())
165169
return tonumber(res[0])
166170
end
167171

@@ -184,13 +188,13 @@ local funcs = {
184188
for name, cname in pairs(funcs) do
185189
Array[name] = function(self)
186190
local res = c_ptr_t('bool')
187-
af.clib['af_' .. cname](res, self._array)
191+
af.clib['af_' .. cname](res, self:get())
188192
return res[0]
189193
end
190194
end
191195

192196
Array.eval = function(self)
193-
af.clib.af_eval(self._array)
197+
af.clib.af_eval(self:get())
194198
end
195199

196200
-- Useful aliases

arrayfire/util.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ funcs[34] = [[
2929
af.lib.cdef(funcs)
3030

3131
af.print = function(arr)
32-
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr._array, 4)
32+
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr:get(), 4)
3333
end

rocks/arrayfire-scm-1.rockspec

+1
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ build = {
2323
["arrayfire.defines"] = "arrayfire/defines.lua",
2424
["arrayfire.device"] = "arrayfire/device.lua",
2525
["arrayfire.dim4"] = "arrayfire/dim4.lua",
26+
["arrayfire.arith"] = "arrayfire/arith.lua",
2627
},
2728
}

0 commit comments

Comments
 (0)