Skip to content

Commit 9cbe920

Browse files
committed
Adding functions from array.h
1 parent 60e06eb commit 9cbe920

File tree

6 files changed

+145
-21
lines changed

6 files changed

+145
-21
lines changed

arrayfire.lua

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
require('arrayfire.lib')
22
require('arrayfire.defines')
3+
require('arrayfire.dim4')
34
require('arrayfire.util')
45
require('arrayfire.array')
56
require('arrayfire.device')

arrayfire/array.lua

+98-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
require('arrayfire.lib')
22
require('arrayfire.defines')
3+
require('arrayfire.dim4')
34
local ffi = require( "ffi" )
45

56
local funcs = {}
@@ -67,15 +68,22 @@ Array.__index = Array
6768

6869
local c_dim4_t = af.ffi.c_dim4_t
6970
local c_uint_t = af.ffi.c_uint_t
70-
local c_array_p = af.ffi.c_array_p
71+
local c_ptr_t = af.ffi.c_ptr_t
72+
local Dim4 = af.Dim4
7173

72-
local add_finalizer = function(arr_ptr)
73-
return ffi.gc(arr_ptr[0], af.clib.af_release_array)
74+
local c_array_p = function(ptr)
75+
local arr_ptr = ffi.new('void *[1]', ptr)
76+
arr_ptr[0] = ffi.gc(arr_ptr[0], af.clib.af_release_array)
77+
return arr_ptr
7478
end
7579

76-
Array.__init = function(data, dims, dtype, source)
80+
local init = function(ptr)
7781
local self = setmetatable({}, Array)
82+
self._array = ptr
83+
return self
84+
end
7885

86+
Array.__init = function(data, dims, dtype, source)
7987
if data then
8088
assert(af.istable(data))
8189
end
@@ -87,32 +95,109 @@ Array.__init = function(data, dims, dtype, source)
8795
c_dims = c_dim4_t(dims or (data and {#data} or {}))
8896
c_ndims = c_uint_t(dims and #dims or (data and 1 or 0))
8997

90-
nelement = 1
98+
count = 1
9199
for i = 1,tonumber(c_ndims) do
92-
nelement = nelement * c_dims[i - 1]
100+
count = count * c_dims[i - 1]
93101
end
94-
nelement = tonumber(nelement)
102+
count = tonumber(count)
95103

96104
local atype = dtype or af.dtype.f32
97105
local res = c_array_p()
98106
if not data then
99107
af.clib.af_create_handle(res, c_ndims, c_dims, atype)
100108
else
101-
c_data = ffi.new(af.dtype_names[atype + 1] .. '[?]', nelement, data)
109+
c_data = c_ptr_t(af.dtype_names[atype + 1], count, data)
102110
af.clib.af_create_array(res, c_data, c_ndims, c_dims, atype)
103111
end
104-
self.__arr = add_finalizer(res)
105-
return self
112+
return Array.init(res[0])
106113
end
107114

108115
Array.__tostring = function(self)
109116
return 'arrayfire.Array\n'
110117
end
111118

112119
Array.get = function(self)
113-
return self.__arr
120+
return self._array
121+
end
122+
123+
-- TODO: implement Array.write
124+
125+
Array.copy = function(self)
126+
local res = c_array_p()
127+
af.clib.af_copy_array(res, self._array)
128+
return Array.init(res[0])
129+
end
130+
131+
Array.softCopy = function(self)
132+
local res = c_array_p()
133+
af.clib.af_copy_array(res, self._array)
134+
return Array.init(res[0])
135+
end
136+
137+
Array.elements = function(self)
138+
local res = c_ptr_t('dim_t')
139+
af.clib.af_get_elements(res, self._array)
140+
return tonumber(res[0])
114141
end
115142

143+
Array.type = function(self)
144+
local res = c_ptr_t('af_dtype')
145+
af.clib.af_get_type(res, self._array)
146+
return tonumber(res[0])
147+
end
148+
149+
Array.typeName = function(self)
150+
local res = c_ptr_t('af_dtype')
151+
af.clib.af_get_type(res, self._array)
152+
return af.dtype_names[tonumber(res[0])]
153+
end
154+
155+
Array.dims = function(self)
156+
local res = c_dim4_t()
157+
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self._array)
158+
return Dim4(tonumber(res[0]), tonumber(res[1]),
159+
tonumber(res[2]), tonumber(res[3]))
160+
end
161+
162+
Array.numdims = function(self)
163+
local res = c_ptr_t('unsigned int')
164+
af.clib.af_get_numdims(res, self._array)
165+
return tonumber(res[0])
166+
end
167+
168+
local funcs = {
169+
isEmpty = 'is_empty',
170+
isScalar = 'is_scalar',
171+
isRow = 'is_row',
172+
isColumn = 'is_column',
173+
isVector = 'is_vector',
174+
isComplex = 'is_complex',
175+
isReal = 'is_real',
176+
isDouble = 'is_double',
177+
isSingle = 'is_single',
178+
isRealFloating = 'is_realfloating',
179+
isFloating = 'is_floating',
180+
isInteger = 'is_integer',
181+
isBool = 'is_bool',
182+
}
183+
184+
for name, cname in pairs(funcs) do
185+
Array[name] = function(self)
186+
local res = c_ptr_t('bool')
187+
af.clib['af_' .. cname](res, self._array)
188+
return res[0]
189+
end
190+
end
191+
192+
Array.eval = function(self)
193+
af.clib.af_eval(self._array)
194+
end
195+
196+
-- Useful aliases
197+
Array.ndims = Array.numdims
198+
Array.nElement = Array.elements
199+
Array.clone = Array.copy
200+
116201
setmetatable(
117202
Array,
118203
{
@@ -124,3 +209,5 @@ setmetatable(
124209

125210
af.Array = Array
126211
af.ffi.add_finalizer = add_finalizer
212+
af.ffi.c_array_p = c_array_p
213+
af.Array.init = init

arrayfire/dim4.lua

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
require('arrayfire.lib')
2+
require('arrayfire.defines')
3+
local ffi = require( "ffi" )
4+
require 'string'
5+
6+
local Dim4 = {}
7+
Dim4.__index = Dim4
8+
9+
setmetatable(
10+
Dim4,
11+
{
12+
__call = function(cls, ...)
13+
return cls.__init(...)
14+
end
15+
}
16+
)
17+
18+
Dim4.__init = function(d1, d2, d3, d4)
19+
local self = setmetatable({d1 or 1, d2 or 1, d3 or 1, d4 or 1}, Dim4)
20+
return self
21+
end
22+
23+
Dim4.__tostring = function(self)
24+
return string.format('[%d, %d, %d, %d]', self[1], self[2], self[3], self[4])
25+
end
26+
27+
Dim4.ndims = function(self)
28+
for i = 4,1,-1 do
29+
if self[i] ~= 1 then
30+
return self[i] == 0 and 0 or i
31+
end
32+
end
33+
return self[1] == 0 and 0 or 1
34+
end
35+
36+
af.Dim4 =Dim4

arrayfire/lib.lua

+4-5
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ af.lib.cdef = function(funcs)
5959
end
6060
end
6161

62-
6362
af.isnumber = function(val)
6463
return type(val) == "number"
6564
end
@@ -74,14 +73,14 @@ af.ffi.c_void_p = function()
7473
return ffi.new('void *')
7574
end
7675

77-
af.ffi.c_array_p = function(ptr)
78-
return ffi.new('void *[1]', ptr)
79-
end
80-
8176
af.ffi.c_dim_t = function(number)
8277
return ffi.new('dim_t', number)
8378
end
8479

80+
af.ffi.c_ptr_t = function(ptr_type, count, values)
81+
return ffi.new(ptr_type .. ' [?]', count or 1, values)
82+
end
83+
8584
af.ffi.c_uint_t = function(number)
8685
return ffi.new('unsigned int', number)
8786
end

arrayfire/util.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ funcs[34] = [[
2929
af.lib.cdef(funcs)
3030

3131
af.print = function(arr)
32-
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr:get(), 4)
32+
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr._array, 4)
3333
end

rocks/arrayfire-scm-1.rockspec

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ build = {
1717
type = "builtin",
1818
modules = {
1919
arrayfire = "arrayfire.lua",
20-
["arrayfire.lib"] = "arrayfire/lib.lua",
21-
["arrayfire.util"] = "arrayfire/util.lua",
22-
["arrayfire.array"] = "arrayfire/array.lua",
20+
["arrayfire.lib"] = "arrayfire/lib.lua",
21+
["arrayfire.util"] = "arrayfire/util.lua",
22+
["arrayfire.array"] = "arrayfire/array.lua",
2323
["arrayfire.defines"] = "arrayfire/defines.lua",
24-
["arrayfire.device"] = "arrayfire/device.lua",
24+
["arrayfire.device"] = "arrayfire/device.lua",
25+
["arrayfire.dim4"] = "arrayfire/dim4.lua",
2526
},
2627
}

0 commit comments

Comments
 (0)