Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/geggo/gpyfft
Browse files Browse the repository at this point in the history
* 'master' of https://github.com/geggo/gpyfft:
  Merge PR #29: Accept pyopencl arrays with nonzero offsets (PR #29), add enqueue_arrays method
  Allow buffers to be of PyOpenCL-type PooledBuffer (in addition to the standard Buffer)
  Don't send an empty event array (clFFT thinks the sky is falling)
  • Loading branch information
geggo committed Jan 26, 2017
2 parents 9b4f777 + 7a8bc41 commit 3952181
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
39 changes: 33 additions & 6 deletions gpyfft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,41 @@ def calculate_transform_strides(cls,

return (tuple(t_strides), t_distance, batchsize, tuple(t_shape), axes_transform)

def enqueue(self, forward = True):
def enqueue(self, forward = True, wait_for_events = None):
return self.enqueue_arrays(forward=forward, data=self.data, result=self.result, wait_for_events=wait_for_events)

def enqueue_arrays(self, data = None, result = None, forward = True, wait_for_events = None):
"""enqueue transform"""
if self.result is not None:
events = self.plan.enqueue_transform((self.queue,), (self.data.data,), (self.result.data),
direction_forward = forward, temp_buffer = self.temp_buffer)
if data is None:
data = self.data
else:
assert data.shape == self.data.shape
assert data.strides == self.data.strides
assert data.dtype == self.data.dtype
if result is None:
result = self.result
else:
events = self.plan.enqueue_transform((self.queue,), (self.data.data,),
direction_forward = forward, temp_buffer = self.temp_buffer)
assert result.shape == self.result.shape
assert result.strides == self.result.strides
assert result.dtype == self.result.dtype

# get buffer for data
if data.offset != 0:
data = data._new_with_changes(data=data.base_data[data.offset:], offset=0)
data_buffer = data.base_data

if result is not None:
# get buffer for result
if result.offset != 0:
result = result._new_with_changes(data=result.base_data[result.offset:], offset=0)
result_buffer = result.base_data

events = self.plan.enqueue_transform((self.queue,), (data_buffer,), (result_buffer),
direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events)
else:
events = self.plan.enqueue_transform((self.queue,), (data_buffer,),
direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events)

return events

def update_arrays(self, input_array, output_array):
Expand Down
12 changes: 6 additions & 6 deletions gpyfft/gpyfftlib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ cdef class Plan(object):
cdef cl_event wait_for_events_array[MAX_WAITFOR_EVENTS]
cdef cl_event* wait_for_events_ = NULL
cdef n_waitfor_events = 0
if wait_for_events is not None:
if wait_for_events is not None and len(wait_for_events) > 0:
n_waitfor_events = len(wait_for_events)
assert n_waitfor_events <= MAX_WAITFOR_EVENTS
for i, event in enumerate(wait_for_events):
Expand All @@ -559,29 +559,29 @@ cdef class Plan(object):
wait_for_events_ = &wait_for_events_array[0]

cdef cl_mem in_buffers_[2]
if isinstance(in_buffers, cl.Buffer):
if isinstance(in_buffers, cl.MemoryObjectHolder):
in_buffers = (in_buffers,)
n_in_buffers = len(in_buffers)
assert n_in_buffers <= 2
for i, in_buffer in enumerate(in_buffers):
assert isinstance(in_buffer, cl.Buffer)
assert isinstance(in_buffer, cl.MemoryObjectHolder)
in_buffers_[i] = <cl_mem><voidptr_t>in_buffer.int_ptr

cdef cl_mem out_buffers_array[2]
cdef cl_mem* out_buffers_ = NULL
if out_buffers is not None:
if isinstance(out_buffers, cl.Buffer):
if isinstance(out_buffers, cl.MemoryObjectHolder):
out_buffers = (out_buffers,)
n_out_buffers = len(out_buffers)
assert n_out_buffers in (1,2)
for i, out_buffer in enumerate(out_buffers):
assert isinstance(out_buffer, cl.Buffer)
assert isinstance(out_buffer, cl.MemoryObjectHolder)
out_buffers_array[i] = <cl_mem><voidptr_t>out_buffer.int_ptr
out_buffers_ = &out_buffers_array[0]

cdef cl_mem tmp_buffer_ = NULL
if temp_buffer is not None:
assert isinstance(temp_buffer, cl.Buffer)
assert isinstance(temp_buffer, cl.MemoryObjectHolder)
tmp_buffer_ = <cl_mem><voidptr_t>temp_buffer.int_ptr

cdef cl_event out_cl_events[MAX_QUEUES]
Expand Down

0 comments on commit 3952181

Please sign in to comment.