diff --git a/gpyfft/fft.py b/gpyfft/fft.py index a1da7db..d005bc1 100644 --- a/gpyfft/fft.py +++ b/gpyfft/fft.py @@ -148,14 +148,41 @@ def calculate_transform_strides(cls, return (tuple(t_strides), t_distance, batchsize, tuple(t_shape), axes_transform) - def enqueue(self, forward = True): + def enqueue(self, forward = True, wait_for_events = None): + return self.enqueue_arrays(forward=forward, data=self.data, result=self.result, wait_for_events=wait_for_events) + + def enqueue_arrays(self, data = None, result = None, forward = True, wait_for_events = None): """enqueue transform""" - if self.result is not None: - events = self.plan.enqueue_transform((self.queue,), (self.data.data,), (self.result.data), - direction_forward = forward, temp_buffer = self.temp_buffer) + if data is None: + data = self.data + else: + assert data.shape == self.data.shape + assert data.strides == self.data.strides + assert data.dtype == self.data.dtype + if result is None: + result = self.result else: - events = self.plan.enqueue_transform((self.queue,), (self.data.data,), - direction_forward = forward, temp_buffer = self.temp_buffer) + assert result.shape == self.result.shape + assert result.strides == self.result.strides + assert result.dtype == self.result.dtype + + # get buffer for data + if data.offset != 0: + data = data._new_with_changes(data=data.base_data[data.offset:], offset=0) + data_buffer = data.base_data + + if result is not None: + # get buffer for result + if result.offset != 0: + result = result._new_with_changes(data=result.base_data[result.offset:], offset=0) + result_buffer = result.base_data + + events = self.plan.enqueue_transform((self.queue,), (data_buffer,), (result_buffer), + direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events) + else: + events = self.plan.enqueue_transform((self.queue,), (data_buffer,), + direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events) + return events def update_arrays(self, input_array, output_array): diff --git a/gpyfft/gpyfftlib.pyx b/gpyfft/gpyfftlib.pyx index f2fa2e6..2d5950d 100644 --- a/gpyfft/gpyfftlib.pyx +++ b/gpyfft/gpyfftlib.pyx @@ -550,7 +550,7 @@ cdef class Plan(object): cdef cl_event wait_for_events_array[MAX_WAITFOR_EVENTS] cdef cl_event* wait_for_events_ = NULL cdef n_waitfor_events = 0 - if wait_for_events is not None: + if wait_for_events is not None and len(wait_for_events) > 0: n_waitfor_events = len(wait_for_events) assert n_waitfor_events <= MAX_WAITFOR_EVENTS for i, event in enumerate(wait_for_events): @@ -559,29 +559,29 @@ cdef class Plan(object): wait_for_events_ = &wait_for_events_array[0] cdef cl_mem in_buffers_[2] - if isinstance(in_buffers, cl.Buffer): + if isinstance(in_buffers, cl.MemoryObjectHolder): in_buffers = (in_buffers,) n_in_buffers = len(in_buffers) assert n_in_buffers <= 2 for i, in_buffer in enumerate(in_buffers): - assert isinstance(in_buffer, cl.Buffer) + assert isinstance(in_buffer, cl.MemoryObjectHolder) in_buffers_[i] = in_buffer.int_ptr cdef cl_mem out_buffers_array[2] cdef cl_mem* out_buffers_ = NULL if out_buffers is not None: - if isinstance(out_buffers, cl.Buffer): + if isinstance(out_buffers, cl.MemoryObjectHolder): out_buffers = (out_buffers,) n_out_buffers = len(out_buffers) assert n_out_buffers in (1,2) for i, out_buffer in enumerate(out_buffers): - assert isinstance(out_buffer, cl.Buffer) + assert isinstance(out_buffer, cl.MemoryObjectHolder) out_buffers_array[i] = out_buffer.int_ptr out_buffers_ = &out_buffers_array[0] cdef cl_mem tmp_buffer_ = NULL if temp_buffer is not None: - assert isinstance(temp_buffer, cl.Buffer) + assert isinstance(temp_buffer, cl.MemoryObjectHolder) tmp_buffer_ = temp_buffer.int_ptr cdef cl_event out_cl_events[MAX_QUEUES]