Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolve a few issues #138

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions mkl_fft/_numpy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,13 @@ def _check_norm(norm):


def frwd_sc_1d(n, s):
nn = n if n else s
nn = n if n is not None else s
return 1/nn if nn != 0 else 1


def frwd_sc_nd(s, axes, x_shape):
ss = s if s is not None else x_shape
if axes is not None:
nn = prod([ss[ai] for ai in axes])
else:
nn = prod(ss)
nn = prod(ss)
return 1/nn if nn != 0 else 1


Expand Down Expand Up @@ -822,7 +819,7 @@ def fftn(a, s=None, axes=None, norm=None):
return trycall(
mkl_fft.fftn,
(x,),
{'shape': s, 'axes': axes,
{'s': s, 'axes': axes,
'fwd_scale': fsc})


Expand Down Expand Up @@ -938,7 +935,7 @@ def ifftn(a, s=None, axes=None, norm=None):
return trycall(
mkl_fft.ifftn,
(x,),
{'shape': s, 'axes': axes,
{'s': s, 'axes': axes,
'fwd_scale': fsc})


Expand Down
94 changes: 50 additions & 44 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig):


def fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1, fsc=fwd_scale)
return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)


def ifft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1, fsc=fwd_scale)
return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)


cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int realQ):
Expand Down Expand Up @@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real


cdef cnp.ndarray __process_arguments(object x, object n, object axis,
object overwrite_arg, object direction,
object overwrite_x, object direction,
long *axis_, long *n_, int *in_place,
int *xnd, int *dir_, int realQ):
"Internal utility to validate and process input arguments of 1D FFT functions"
Expand All @@ -213,7 +213,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis,
else:
dir_[0] = -1 if direction is -1 else +1

in_place[0] = 1 if overwrite_arg is True else 0
in_place[0] = 1 if overwrite_x else 0

# convert x to ndarray, ensure that strides are multiples of itemsize
x_arr = PyArray_CheckFromAny(
Expand Down Expand Up @@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f
# Float/double inputs are not cast to complex, but are effectively
# treated as complexes with zero imaginary parts.
# All other types are cast to complex double.
def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fsc=1.0):
def _fft1d_impl(x, n=None, axis=-1, overwrite_x=False, direction=+1, double fsc=1.0):
"""
Uses MKL to perform 1D FFT on the input array x along the given axis.
"""
Expand All @@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 0)

x_type = cnp.PyArray_TYPE(x_arr)
Expand Down Expand Up @@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs

def rfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
"""Packed real-valued harmonics of FFT of a real sequence x"""
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)


def irfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0):
"""Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT"""
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale)
return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale)


cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type):
Expand Down Expand Up @@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis):
return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type)


def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
"""
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.

This done by using rfft and post-processing the result.
Thus overwrite_arg is effectively discarded.
Thus overwrite_x is effectively discarded.

Functionally equivalent to scipy.fftpack.rfft
"""
Expand All @@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(+1),
x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(+1),
&axis_, &n_, &in_place, &xnd, &dir_, 1)

x_type = cnp.PyArray_TYPE(x_arr)
Expand Down Expand Up @@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
return _rc_to_rr(f_arr, n_, axis_, xnd, x_type)


def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
"""
Uses MKL to perform real packed 1D FFT on the input array x along the given axis.

This done by using rfft and post-processing the result.
Thus overwrite_arg is effectively discarded.
Thus overwrite_x is effectively discarded.

Functionally equivalent to scipy.fftpack.irfft
"""
Expand All @@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, <object>(-1),
x_arr = __process_arguments(x, n, axis, overwrite_x, <object>(-1),
&axis_, &n_, &in_place, &xnd, &dir_, 1)

x_type = cnp.PyArray_TYPE(x_arr)
Expand Down Expand Up @@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):


# this routine is functionally equivalent to numpy.fft.rfft
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
"""
Uses MKL to perform 1D FFT on the real input array x along the given axis,
producing complex output, but giving only half of the harmonics.
Expand All @@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
cdef bytes py_error_msg
cdef DftiCache *_cache

x_arr = __process_arguments(x, n, axis, overwrite_arg, direction,
x_arr = __process_arguments(x, n, axis, overwrite_x, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 1)

x_type = cnp.PyArray_TYPE(x_arr)

if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE or x_type is cnp.NPY_CLONGDOUBLE:
raise TypeError("1st argument must be a real sequence 1")
raise TypeError("1st argument must be a real sequence.")
elif x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE:
pass
else:
Expand Down Expand Up @@ -723,7 +723,7 @@ cdef int _is_integral(object num):


# this routine is functionally equivalent to numpy.fft.irfft
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0):
"""
Uses MKL to perform 1D FFT on the real input array x along the given axis,
producing complex output, but giving only half of the harmonics.
Expand All @@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0):
int_n = _is_integral(n)
# nn gives the number elements along axis of the input that we use
nn = (n // 2 + 1) if int_n and n > 0 else n
x_arr = __process_arguments(x, nn, axis, overwrite_arg, direction,
x_arr = __process_arguments(x, nn, axis, overwrite_x, direction,
&axis_, &n_, &in_place, &xnd, &dir_, 0)
n_ = 2*(n_ - 1)
if int_n and (n % 2 == 1):
Expand Down Expand Up @@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0):
return s, axes


def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_function=lambda n, ind: 1.0):
def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_x=False, scale_function=lambda n, ind: 1.0):
a = np.asarray(a)
s, axes = _init_nd_shape_and_axes(a, s, axes)
ovwr = overwrite_arg
ovwr = overwrite_x
for ii in reversed(range(len(axes))):
a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr, fwd_scale=scale_function(s[ii], ii))
ovwr = True
Expand Down Expand Up @@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result):
return result


def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
def _direct_fftnd(x, overwrite_x=False, direction=+1, double fsc=1.0):
"""Perform n-dimensional FFT over all axes"""
cdef int err
cdef long n_max = 0
Expand All @@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
else:
dir_ = -1 if direction is -1 else +1

in_place = 1 if overwrite_arg is True else 0
in_place = 1 if overwrite_x else 0

# convert x to ndarray, ensure that strides are multiples of itemsize
x_arr = PyArray_CheckFromAny(
Expand Down Expand Up @@ -1069,56 +1069,56 @@ def _output_dtype(dt):
return dt


def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
def _fftnd_impl(x, s=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
if direction not in [-1, +1]:
raise ValueError("Direction of FFT should +1 or -1")

# _direct_fftnd requires complex type, and full-dimensional transform
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
_direct = shape is None and axes is None
_direct = s is None and axes is None
if _direct:
_direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D
if not _direct:
xs, xa = _cook_nd_args(x, shape, axes)
xs, xa = _cook_nd_args(x, s, axes)
if _check_shapes_for_direct(xs, x.shape, xa):
_direct = True
_direct = _direct and x.dtype in [np.complex64, np.complex128, np.float32, np.float64]
else:
_direct = False

if _direct:
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
return _direct_fftnd(x, overwrite_x=overwrite_x, direction=direction, fsc=fsc)
else:
if (shape is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
if (s is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]):
x = np.asarray(x)
res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
return iter_complementary(
x, axes,
_direct_fftnd,
{'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc},
{'overwrite_x': overwrite_x, 'direction': direction, 'fsc': fsc},
res
)
else:
sc = <object> fsc
return _iter_fftnd(x, s=shape, axes=axes,
overwrite_arg=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
return _iter_fftnd(x, s=s, axes=axes,
overwrite_x=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1.,
function=fft if direction == 1 else ifft)


def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
def fft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)


def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
def ifft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)


def fftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)
def fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale)


def ifftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)
def ifftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0):
return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale)


def rfft2(x, s=None, axes=(-2,-1), fwd_scale=1.0):
Expand Down Expand Up @@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes):
raise ValueError("Invalid axis (%d) specified" % ai)
if si < shp_i:
if no_trim:
ind = [slice(None,None,None),] * len(s)
ind = [slice(None,None,None),] * len(arr_shape)
no_trim = False
ind[ai] = slice(None, si, None)
if no_trim:
Expand Down Expand Up @@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0):
tind = tuple(ind)
a_inp = a[tind]
a_res = _fftnd_impl(
a_inp, shape=ss, axes=aa,
a_inp, s=ss, axes=aa,
overwrite_x=True, direction=1)
if a_res is not a_inp:
a[tind] = a_res # copy in place
else:
for ii in range(len(axes)-1):
for ii in range(len(axes)-2, -1, -1):
a = fft(a, s[ii], axes[ii], overwrite_x=True)
return a

Expand All @@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
no_trim = (s is None) and (axes is None)
s, axes = _cook_nd_args(a, s, axes, invreal=True)
la = axes[-1]
if not no_trim:
a = _trim_array(a, s, axes)
if len(s) > 1:
if not no_trim:
a = _fix_dimensions(a, s, axes)
Expand All @@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0):
if not ovr_x:
a = a.copy()
ovr_x = True
if not np.issubdtype(a.dtype, np.complexfloating):
# copy is needed, because output of complex type will be copied to input
a = a.astype(np.complex64) if a.dtype == np.float32 else a.astype(np.complex128)
ovr_x = True
ss, aa = _remove_axis(s, axes, -1)
ind = [slice(None,None,1),] * len(s)
ind = [slice(None, None, 1),] * len(s)
for ii in range(a.shape[la]):
ind[la] = ii
tind = tuple(ind)
a_inp = a[tind]
a_res = _fftnd_impl(
a_inp, shape=ss, axes=aa,
a_inp, s=ss, axes=aa,
overwrite_x=True, direction=-1)
if a_res is not a_inp:
a[tind] = a_res # copy in place
Expand Down
8 changes: 4 additions & 4 deletions mkl_fft/_scipy_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, pl
fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
_check_plan(plan)
with Workers(workers):
output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
return output


Expand All @@ -293,7 +293,7 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, p
fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
_check_plan(plan)
with Workers(workers):
output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
return output


Expand All @@ -307,7 +307,7 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=
fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
_check_plan(plan)
with Workers(workers):
output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
return output


Expand All @@ -321,7 +321,7 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan
fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape)
_check_plan(plan)
with Workers(workers):
output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc)
return output


Expand Down
Loading
Loading