diff --git a/mkl_fft/_numpy_fft.py b/mkl_fft/_numpy_fft.py index 2c832d5..0a0a07d 100644 --- a/mkl_fft/_numpy_fft.py +++ b/mkl_fft/_numpy_fft.py @@ -71,16 +71,13 @@ def _check_norm(norm): def frwd_sc_1d(n, s): - nn = n if n else s + nn = n if n is not None else s return 1/nn if nn != 0 else 1 -def frwd_sc_nd(s, axes, x_shape): +def frwd_sc_nd(s, x_shape): ss = s if s is not None else x_shape - if axes is not None: - nn = prod([ss[ai] for ai in axes]) - else: - nn = prod(ss) + nn = prod(ss) return 1/nn if nn != 0 else 1 @@ -815,14 +812,14 @@ def fftn(a, s=None, axes=None, norm=None): if norm in (None, "backward"): fsc = 1.0 elif norm == "forward": - fsc = frwd_sc_nd(s, axes, x.shape) + fsc = frwd_sc_nd(s, x.shape) else: - fsc = sqrt(frwd_sc_nd(s, axes, x.shape)) + fsc = sqrt(frwd_sc_nd(s, x.shape)) return trycall( mkl_fft.fftn, (x,), - {'shape': s, 'axes': axes, + {'s': s, 'axes': axes, 'fwd_scale': fsc}) @@ -931,14 +928,14 @@ def ifftn(a, s=None, axes=None, norm=None): if norm in (None, "backward"): fsc = 1.0 elif norm == "forward": - fsc = frwd_sc_nd(s, axes, x.shape) + fsc = frwd_sc_nd(s, x.shape) else: - fsc = sqrt(frwd_sc_nd(s, axes, x.shape)) + fsc = sqrt(frwd_sc_nd(s, x.shape)) return trycall( mkl_fft.ifftn, (x,), - {'shape': s, 'axes': axes, + {'s': s, 'axes': axes, 'fwd_scale': fsc}) @@ -1230,11 +1227,11 @@ def rfftn(a, s=None, axes=None, norm=None): elif norm == "forward": x = asanyarray(x) s, axes = _cook_nd_args(x, s, axes) - fsc = frwd_sc_nd(s, axes, x.shape) + fsc = frwd_sc_nd(s, x.shape) else: x = asanyarray(x) s, axes = _cook_nd_args(x, s, axes) - fsc = sqrt(frwd_sc_nd(s, axes, x.shape)) + fsc = sqrt(frwd_sc_nd(s, x.shape)) return trycall( mkl_fft.rfftn, @@ -1387,11 +1384,11 @@ def irfftn(a, s=None, axes=None, norm=None): elif norm == "forward": x = asanyarray(x) s, axes = _cook_nd_args(x, s, axes, invreal=1) - fsc = frwd_sc_nd(s, axes, x.shape) + fsc = frwd_sc_nd(s, x.shape) else: x = asanyarray(x) s, axes = _cook_nd_args(x, s, axes, invreal=1) - fsc = sqrt(frwd_sc_nd(s, axes, x.shape)) + fsc = sqrt(frwd_sc_nd(s, x.shape)) return trycall( mkl_fft.irfftn, diff --git a/mkl_fft/_pydfti.pyx b/mkl_fft/_pydfti.pyx index 47e0bfd..ded8ce1 100644 --- a/mkl_fft/_pydfti.pyx +++ b/mkl_fft/_pydfti.pyx @@ -157,11 +157,11 @@ cdef int _datacopied(cnp.ndarray arr, object orig): def fft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0): - return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=+1, fsc=fwd_scale) + return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale) def ifft(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0): - return _fft1d_impl(x, n=n, axis=axis, overwrite_arg=overwrite_x, direction=-1, fsc=fwd_scale) + return _fft1d_impl(x, n=n, axis=axis, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale) cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int realQ): @@ -200,7 +200,7 @@ cdef cnp.ndarray pad_array(cnp.ndarray x_arr, cnp.npy_intp n, int axis, int real cdef cnp.ndarray __process_arguments(object x, object n, object axis, - object overwrite_arg, object direction, + object overwrite_x, object direction, long *axis_, long *n_, int *in_place, int *xnd, int *dir_, int realQ): "Internal utility to validate and process input arguments of 1D FFT functions" @@ -213,7 +213,7 @@ cdef cnp.ndarray __process_arguments(object x, object n, object axis, else: dir_[0] = -1 if direction is -1 else +1 - in_place[0] = 1 if overwrite_arg is True else 0 + in_place[0] = 1 if overwrite_x else 0 # convert x to ndarray, ensure that strides are multiples of itemsize x_arr = PyArray_CheckFromAny( @@ -294,7 +294,7 @@ cdef cnp.ndarray __allocate_result(cnp.ndarray x_arr, long n_, long axis_, int f # Float/double inputs are not cast to complex, but are effectively # treated as complexes with zero imaginary parts. # All other types are cast to complex double. -def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fsc=1.0): +def _fft1d_impl(x, n=None, axis=-1, overwrite_x=False, direction=+1, double fsc=1.0): """ Uses MKL to perform 1D FFT on the input array x along the given axis. """ @@ -308,7 +308,7 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs cdef bytes py_error_msg cdef DftiCache *_cache - x_arr = __process_arguments(x, n, axis, overwrite_arg, direction, + x_arr = __process_arguments(x, n, axis, overwrite_x, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0) x_type = cnp.PyArray_TYPE(x_arr) @@ -410,12 +410,12 @@ def _fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, direction=+1, double fs def rfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0): """Packed real-valued harmonics of FFT of a real sequence x""" - return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale) + return _rr_fft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale) def irfftpack(x, n=None, axis=-1, overwrite_x=False, fwd_scale=1.0): """Inverse FFT of a real sequence, takes packed real-valued harmonics of FFT""" - return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_arg=overwrite_x, fsc=fwd_scale) + return _rr_ifft1d_impl2(x, n=n, axis=axis, overwrite_x=overwrite_x, fsc=fwd_scale) cdef object _rc_to_rr(cnp.ndarray rc_arr, int n, int axis, int xnd, int x_type): @@ -520,12 +520,12 @@ def _repack_rc_to_rr(x, n, axis): return _rc_to_rr(x, n_, axis_, cnp.PyArray_NDIM(x_arr), x_type) -def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): +def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0): """ Uses MKL to perform real packed 1D FFT on the input array x along the given axis. This done by using rfft and post-processing the result. - Thus overwrite_arg is effectively discarded. + Thus overwrite_x is effectively discarded. Functionally equivalent to scipy.fftpack.rfft """ @@ -539,7 +539,7 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): cdef bytes py_error_msg cdef DftiCache *_cache - x_arr = __process_arguments(x, n, axis, overwrite_arg, (+1), + x_arr = __process_arguments(x, n, axis, overwrite_x, (+1), &axis_, &n_, &in_place, &xnd, &dir_, 1) x_type = cnp.PyArray_TYPE(x_arr) @@ -576,12 +576,12 @@ def _rr_fft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): return _rc_to_rr(f_arr, n_, axis_, xnd, x_type) -def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): +def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0): """ Uses MKL to perform real packed 1D FFT on the input array x along the given axis. This done by using rfft and post-processing the result. - Thus overwrite_arg is effectively discarded. + Thus overwrite_x is effectively discarded. Functionally equivalent to scipy.fftpack.irfft """ @@ -595,7 +595,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): cdef bytes py_error_msg cdef DftiCache *_cache - x_arr = __process_arguments(x, n, axis, overwrite_arg, (-1), + x_arr = __process_arguments(x, n, axis, overwrite_x, (-1), &axis_, &n_, &in_place, &xnd, &dir_, 1) x_type = cnp.PyArray_TYPE(x_arr) @@ -645,7 +645,7 @@ def _rr_ifft1d_impl2(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): # this routine is functionally equivalent to numpy.fft.rfft -def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): +def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0): """ Uses MKL to perform 1D FFT on the real input array x along the given axis, producing complex output, but giving only half of the harmonics. @@ -663,13 +663,13 @@ def _rc_fft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): cdef bytes py_error_msg cdef DftiCache *_cache - x_arr = __process_arguments(x, n, axis, overwrite_arg, direction, + x_arr = __process_arguments(x, n, axis, overwrite_x, direction, &axis_, &n_, &in_place, &xnd, &dir_, 1) x_type = cnp.PyArray_TYPE(x_arr) if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE or x_type is cnp.NPY_CLONGDOUBLE: - raise TypeError("1st argument must be a real sequence 1") + raise TypeError("1st argument must be a real sequence.") elif x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_DOUBLE: pass else: @@ -723,7 +723,7 @@ cdef int _is_integral(object num): # this routine is functionally equivalent to numpy.fft.irfft -def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): +def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_x=False, double fsc=1.0): """ Uses MKL to perform 1D FFT on the real input array x along the given axis, producing complex output, but giving only half of the harmonics. @@ -743,7 +743,7 @@ def _rc_ifft1d_impl(x, n=None, axis=-1, overwrite_arg=False, double fsc=1.0): int_n = _is_integral(n) # nn gives the number elements along axis of the input that we use nn = (n // 2 + 1) if int_n and n > 0 else n - x_arr = __process_arguments(x, nn, axis, overwrite_arg, direction, + x_arr = __process_arguments(x, nn, axis, overwrite_x, direction, &axis_, &n_, &in_place, &xnd, &dir_, 0) n_ = 2*(n_ - 1) if int_n and (n % 2 == 1): @@ -907,10 +907,10 @@ def _cook_nd_args(a, s=None, axes=None, invreal=0): return s, axes -def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_function=lambda n, ind: 1.0): +def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_x=False, scale_function=lambda n, ind: 1.0): a = np.asarray(a) s, axes = _init_nd_shape_and_axes(a, s, axes) - ovwr = overwrite_arg + ovwr = overwrite_x for ii in reversed(range(len(axes))): a = function(a, n = s[ii], axis = axes[ii], overwrite_x=ovwr, fwd_scale=scale_function(s[ii], ii)) ovwr = True @@ -959,7 +959,7 @@ def iter_complementary(x, axes, func, kwargs, result): return result -def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0): +def _direct_fftnd(x, overwrite_x=False, direction=+1, double fsc=1.0): """Perform n-dimensional FFT over all axes""" cdef int err cdef long n_max = 0 @@ -972,7 +972,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0): else: dir_ = -1 if direction is -1 else +1 - in_place = 1 if overwrite_arg is True else 0 + in_place = 1 if overwrite_x else 0 # convert x to ndarray, ensure that strides are multiples of itemsize x_arr = PyArray_CheckFromAny( @@ -1069,17 +1069,17 @@ def _output_dtype(dt): return dt -def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0): +def _fftnd_impl(x, s=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0): if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") # _direct_fftnd requires complex type, and full-dimensional transform if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: - _direct = shape is None and axes is None + _direct = s is None and axes is None if _direct: _direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D if not _direct: - xs, xa = _cook_nd_args(x, shape, axes) + xs, xa = _cook_nd_args(x, s, axes) if _check_shapes_for_direct(xs, x.shape, xa): _direct = True _direct = _direct and x.dtype in [np.complex64, np.complex128, np.float32, np.float64] @@ -1087,38 +1087,38 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl _direct = False if _direct: - return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc) + return _direct_fftnd(x, overwrite_x=overwrite_x, direction=direction, fsc=fsc) else: - if (shape is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]): + if (s is None and x.dtype in [np.csingle, np.cdouble, np.single, np.double]): x = np.asarray(x) res = np.empty(x.shape, dtype=_output_dtype(x.dtype)) return iter_complementary( x, axes, _direct_fftnd, - {'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc}, + {'overwrite_x': overwrite_x, 'direction': direction, 'fsc': fsc}, res ) else: sc = fsc - return _iter_fftnd(x, s=shape, axes=axes, - overwrite_arg=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1., + return _iter_fftnd(x, s=s, axes=axes, + overwrite_x=overwrite_x, scale_function=lambda n, i: sc if i == 0 else 1., function=fft if direction == 1 else ifft) -def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0): - return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale) +def fft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0): + return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale) -def ifft2(x, shape=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0): - return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale) +def ifft2(x, s=None, axes=(-2,-1), overwrite_x=False, fwd_scale=1.0): + return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale) -def fftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0): - return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale) +def fftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0): + return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=+1, fsc=fwd_scale) -def ifftn(x, shape=None, axes=None, overwrite_x=False, fwd_scale=1.0): - return _fftnd_impl(x, shape=shape, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale) +def ifftn(x, s=None, axes=None, overwrite_x=False, fwd_scale=1.0): + return _fftnd_impl(x, s=s, axes=axes, overwrite_x=overwrite_x, direction=-1, fsc=fwd_scale) def rfft2(x, s=None, axes=(-2,-1), fwd_scale=1.0): @@ -1154,7 +1154,7 @@ cdef cnp.ndarray _trim_array(cnp.ndarray arr, object s, object axes): raise ValueError("Invalid axis (%d) specified" % ai) if si < shp_i: if no_trim: - ind = [slice(None,None,None),] * len(s) + ind = [slice(None,None,None),] * len(arr_shape) no_trim = False ind[ai] = slice(None, si, None) if no_trim: @@ -1203,12 +1203,12 @@ def rfftn(x, s=None, axes=None, fwd_scale=1.0): tind = tuple(ind) a_inp = a[tind] a_res = _fftnd_impl( - a_inp, shape=ss, axes=aa, + a_inp, s=ss, axes=aa, overwrite_x=True, direction=1) if a_res is not a_inp: a[tind] = a_res # copy in place else: - for ii in range(len(axes)-1): + for ii in range(len(axes) - 2, -1, -1): a = fft(a, s[ii], axes[ii], overwrite_x=True) return a @@ -1218,6 +1218,8 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0): no_trim = (s is None) and (axes is None) s, axes = _cook_nd_args(a, s, axes, invreal=True) la = axes[-1] + if not no_trim: + a = _trim_array(a, s, axes) if len(s) > 1: if not no_trim: a = _fix_dimensions(a, s, axes) @@ -1227,14 +1229,18 @@ def irfftn(x, s=None, axes=None, fwd_scale=1.0): if not ovr_x: a = a.copy() ovr_x = True + if not np.issubdtype(a.dtype, np.complexfloating): + # copy is needed, because output of complex type will be copied to input + a = a.astype(np.complex64) if a.dtype == np.float32 else a.astype(np.complex128) + ovr_x = True ss, aa = _remove_axis(s, axes, -1) - ind = [slice(None,None,1),] * len(s) + ind = [slice(None, None, 1),] * len(s) for ii in range(a.shape[la]): ind[la] = ii tind = tuple(ind) a_inp = a[tind] a_res = _fftnd_impl( - a_inp, shape=ss, axes=aa, + a_inp, s=ss, axes=aa, overwrite_x=True, direction=-1) if a_res is not a_inp: a[tind] = a_res # copy in place diff --git a/mkl_fft/_scipy_fft.py b/mkl_fft/_scipy_fft.py index 4d3c9ac..a9dd98b 100644 --- a/mkl_fft/_scipy_fft.py +++ b/mkl_fft/_scipy_fft.py @@ -200,16 +200,13 @@ def _check_plan(plan): def _frwd_sc_1d(n, s): - nn = n if n else s + nn = n if n is not None else s return 1/nn if nn != 0 else 1 -def _frwd_sc_nd(s, axes, x_shape): +def _frwd_sc_nd(s, x_shape): ss = s if s is not None else x_shape - if axes is not None: - nn = prod([ss[ai] for ai in axes]) - else: - nn = prod(ss) + nn = prod(ss) return 1/nn if nn != 0 else 1 @@ -233,9 +230,9 @@ def _compute_nd_fwd_scale(norm, s, axes, x_shape): if norm in (None, "backward"): fsc = 1.0 elif norm == "forward": - fsc = _frwd_sc_nd(s, axes, x_shape) + fsc = _frwd_sc_nd(s, x_shape) elif norm == "ortho": - fsc = sqrt(_frwd_sc_nd(s, axes, x_shape)) + fsc = sqrt(_frwd_sc_nd(s, x_shape)) else: _check_norm(norm) return fsc @@ -279,7 +276,7 @@ def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, pl fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape) _check_plan(plan) with Workers(workers): - output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) + output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) return output @@ -293,7 +290,7 @@ def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, p fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape) _check_plan(plan) with Workers(workers): - output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) + output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) return output @@ -307,7 +304,7 @@ def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan= fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape) _check_plan(plan) with Workers(workers): - output = mkl_fft.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) + output = mkl_fft.fftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) return output @@ -321,7 +318,7 @@ def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan fsc = _compute_nd_fwd_scale(norm, s, axes, x.shape) _check_plan(plan) with Workers(workers): - output = mkl_fft.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) + output = mkl_fft.ifftn(x, s=s, axes=axes, overwrite_x=overwrite_x, fwd_scale=fsc) return output @@ -359,10 +356,10 @@ def _compute_nd_fwd_scale_for_rfft(norm, s, axes, x, invreal=False): fsc = 1.0 elif norm == "forward": s, axes = _cook_nd_args(x, s, axes, invreal=invreal) - fsc = _frwd_sc_nd(s, axes, x.shape) + fsc = _frwd_sc_nd(s, x.shape) elif norm == "ortho": s, axes = _cook_nd_args(x, s, axes, invreal=invreal) - fsc = sqrt(_frwd_sc_nd(s, axes, x.shape)) + fsc = sqrt(_frwd_sc_nd(s, x.shape)) else: _check_norm(norm) return s, axes, fsc diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 20c73e4..da02abd 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -31,7 +31,7 @@ from numpy import random as rnd import sys import warnings - +import pytest import mkl_fft reps_64 = (2**11)*np.finfo(np.float64).eps @@ -162,7 +162,7 @@ def test_gh64(self): a = np.arange(12).reshape((3,4)) x = a.astype(np.cdouble) # should executed successfully - r1 = mkl_fft.fftn(a, shape=None, axes=(-2,-1)) + r1 = mkl_fft.fftn(a, s=None, axes=(-2,-1)) r2 = mkl_fft.fftn(x) r_tol, a_tol = _get_rtol_atol(x) assert_allclose(r1, r2, rtol=r_tol, atol=a_tol) @@ -223,8 +223,43 @@ def test_gh109(): b_int = np.array([[5, 7, 6, 5], [4, 6, 4, 8], [9, 3, 7, 5]], dtype=np.int64) b = np.asarray(b_int, dtype=np.float32) - r1 = mkl_fft.fftn(b, shape=None, axes=(0,), overwrite_x=False, fwd_scale=1/3) - r2 = mkl_fft.fftn(b_int, shape=None, axes=(0,), overwrite_x=False, fwd_scale=1/3) + r1 = mkl_fft.fftn(b, s=None, axes=(0,), overwrite_x=False, fwd_scale=1/3) + r2 = mkl_fft.fftn(b_int, s=None, axes=(0,), overwrite_x=False, fwd_scale=1/3) rtol, atol = _get_rtol_atol(b) assert_allclose(r1, r2, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("dtype", [complex, float]) +@pytest.mark.parametrize("s", [(15, 24, 10), [35, 25, 15], [25, 15, 5]]) +@pytest.mark.parametrize("axes", [(0, 1, 2), (-1, -2, -3), [1, 0, 2]]) +@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn", "irfftn"]) +def test_s_axes(dtype, s, axes, func): + shape = (30, 20, 10) + if dtype is complex and func != "rfftn": + x = np.random.random(shape) + 1j * np.random.random(shape) + else: + x = np.random.random(shape) + + r1 = getattr(mkl_fft, func)(x, s=s, axes=axes) + r2 = getattr(np.fft, func)(x, s=s, axes=axes) + + rtol, atol = _get_rtol_atol(x) + assert_allclose(r1, r2, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("dtype", [complex, float]) +@pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)]) +@pytest.mark.parametrize("func", ["rfftn", "irfftn"]) +def test_repeated_axes(dtype, axes, func): + shape = (2, 3, 4, 5) + if dtype is complex and func != "rfftn": + x = np.random.random(shape) + 1j * np.random.random(shape) + else: + x = np.random.random(shape) + + r1 = getattr(mkl_fft, func)(x, axes=axes) + r2 = getattr(np.fft, func)(x, axes=axes) + + rtol, atol = _get_rtol_atol(x) + assert_allclose(r1, r2, rtol=rtol, atol=atol) diff --git a/mkl_fft/tests/test_interfaces.py b/mkl_fft/tests/test_interfaces.py index 954ed56..221e6f5 100644 --- a/mkl_fft/tests/test_interfaces.py +++ b/mkl_fft/tests/test_interfaces.py @@ -29,14 +29,6 @@ import numpy as np -def test_interfaces_has_numpy(): - assert hasattr(mfi, 'numpy_fft') - - -def test_interfaces_has_scipy(): - assert hasattr(mfi, 'scipy_fft') - - @pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"]) @pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128]) def test_scipy_fft(norm, dtype): @@ -151,3 +143,15 @@ def test_scipy_fft_arg_validate(): with pytest.raises(NotImplementedError): mfi.scipy_fft.fft([1,2,3,4], plan="magic") + +@pytest.mark.parametrize( + "func", + [mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2], + ids=["scipy", "numpy"], +) +def test_axes(func): + x = np.arange(24.).reshape(2, 3, 4) + res = func(x, axes=(1, 2)) + exp = np.fft.rfft2(x, axes=(1, 2)) + tol = 64 * np.finfo(np.float64).eps + assert np.allclose(res, exp, atol=tol, rtol=tol) diff --git a/mkl_fft/tests/test_pocketfft.py b/mkl_fft/tests/test_pocketfft.py index aef1cdd..7f006a0 100644 --- a/mkl_fft/tests/test_pocketfft.py +++ b/mkl_fft/tests/test_pocketfft.py @@ -37,8 +37,7 @@ def test_identity(self): assert_allclose(mkl_fft.irfft(mkl_fft.rfft(xr[0:i]), i), xr[0:i], atol=1e-12) - @pytest.mark.skip() - @pytest.mark.parametrize("dtype", [np.single, np.double, np.longdouble]) + @pytest.mark.parametrize("dtype", [np.single, np.double]) #, np.longdouble]) def test_identity_long_short(self, dtype): # Test with explicitly given number of points, both for n # smaller and for n larger than the input size. @@ -56,8 +55,7 @@ def test_identity_long_short(self, dtype): assert check_r.dtype == dtype assert_allclose(check_r, xxr[0:i], atol=atol, rtol=0) - @pytest.mark.skip() - @pytest.mark.parametrize("dtype", [np.single, np.double, np.longdouble]) + @pytest.mark.parametrize("dtype", [np.single, np.double]) #, np.longdouble]) def test_identity_long_short_reversed(self, dtype): # Also test explicitly given number of points in reversed order. maxlen = 16 @@ -307,7 +305,6 @@ def test_irfft2(self): assert_allclose(x, mkl_fft.irfft2(mkl_fft.rfft2(x, norm="forward"), norm="forward"), atol=1e-6) - @pytest.mark.skip("repeated axes") def test_rfftn(self): x = random((30, 20, 10)) assert_allclose(mkl_fft.fftn(x)[:, :, :6], mkl_fft.rfftn(x), atol=1e-6) @@ -360,7 +357,6 @@ def test_ihfft(self): assert_allclose(x_herm, mkl_fft.ihfft(mkl_fft.hfft(x_herm, norm="forward"), norm="forward"), atol=1e-6) - @pytest.mark.skip("Casting complex values to real") @pytest.mark.parametrize("op", [mkl_fft.fftn, mkl_fft.ifftn, mkl_fft.rfftn, mkl_fft.irfftn]) def test_axes(self, op): @@ -483,7 +479,6 @@ def test_irfftn_out_and_s_interaction(self, s): assert_array_equal(result, expected) -@pytest.mark.skip() @pytest.mark.parametrize( "dtype", [np.float32, np.float64, np.complex64, np.complex128]) @@ -518,7 +513,7 @@ def test_fft_with_order(dtype, order, fft): for ax in axes: X_res = fft(X, axes=ax) Y_res = fft(Y, axes=ax) - assert_allclose(X_res, Y_res, atol=_tol, rtol=_tol) + assert_allclose(X_res, Y_res, atol=_tol, rtol=10 * _tol) else: raise ValueError @@ -591,7 +586,6 @@ def test_irfft_with_n_large_regression(): assert_allclose(result, expected) -@pytest.mark.skip() @pytest.mark.parametrize("fft", [ mkl_fft.fft, mkl_fft.ifft, mkl_fft.rfft, mkl_fft.irfft ]) @@ -605,4 +599,4 @@ def test_fft_with_integer_or_bool_input(data, fft): result = fft(data) float_data = data.astype(np.result_type(data, 1.)) expected = fft(float_data) - assert_array_equal(result, expected) + assert_allclose(result, expected, rtol=1e-15)