Skip to content

Commit

Permalink
enable transform (OpenCL) callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
geggo committed Jan 26, 2017
1 parent 823fa32 commit 9b4f777
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion gpyfft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import numpy as np

# TODO:
# real to complex: out-of-place

class FFT(object):
def __init__(self, context, queue, in_array, out_array=None, axes = None,
fast_math = False,
real=False,
callbacks=None, #dict: 'pre', 'post'
):
# Callbacks: dict(pre=b'pre source (kernel named pre!)')
self.context = context
self.queue = queue

Expand Down Expand Up @@ -77,6 +78,16 @@ def __init__(self, context, queue, in_array, out_array=None, axes = None,
plan.precision = precision
plan.layouts = (layout_in, layout_out)

if callbacks is not None:
if callbacks.has_key('pre'):
plan.set_callback(b'pre',
callbacks['pre'],
'pre')
if 'post' in callbacks:
plan.set_callback(b'post',
callbacks['post'],
'post')

if False:
print('axes', axes )
print('in_array.shape: ', in_array.shape)
Expand Down

0 comments on commit 9b4f777

Please sign in to comment.