Skip to content
39 changes: 33 additions & 6 deletions gpyfft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,41 @@ def calculate_transform_strides(cls,

return (tuple(t_strides), t_distance, batchsize, tuple(t_shape), axes_transform)

def enqueue(self, forward = True):
def enqueue(self, forward = True, wait_for_events = None):
return self.enqueue_arrays(forward=forward, data=self.data, result=self.result, wait_for_events=wait_for_events)

def enqueue_arrays(self, data = None, result = None, forward = True, wait_for_events = None):
"""enqueue transform"""
if self.result is not None:
events = self.plan.enqueue_transform((self.queue,), (self.data.data,), (self.result.data),
direction_forward = forward, temp_buffer = self.temp_buffer)
if data is None:
data = self.data
else:
assert data.shape == self.data.shape
assert data.strides == self.data.strides
assert data.dtype == self.data.dtype
if result is None:
result = self.result
else:
events = self.plan.enqueue_transform((self.queue,), (self.data.data,),
direction_forward = forward, temp_buffer = self.temp_buffer)
assert result.shape == self.result.shape
assert result.strides == self.result.strides
assert result.dtype == self.result.dtype

# get buffer for data
if data.offset != 0:
data = data._new_with_changes(data=data.base_data[data.offset:], offset=0)
data_buffer = data.base_data

if result is not None:
# get buffer for result
if result.offset != 0:
result = result._new_with_changes(data=result.base_data[result.offset:], offset=0)
result_buffer = result.base_data

events = self.plan.enqueue_transform((self.queue,), (data_buffer,), (result_buffer),
direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events)
else:
events = self.plan.enqueue_transform((self.queue,), (data_buffer,),
direction_forward = forward, temp_buffer = self.temp_buffer, wait_for_events = wait_for_events)

return events

def update_arrays(self, input_array, output_array):
Expand Down