Skip to content

Commit b2af92d

Browse files
Merge pull request #56 from IntelPython/fix-issue-48
Fixed issue #48
2 parents cc946a1 + 6fd2760 commit b2af92d

File tree

3 files changed

+85
-5
lines changed

3 files changed

+85
-5
lines changed

mkl_fft/_numpy_fft.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
g#!/usr/bin/env python
1+
#!/usr/bin/env python
22
# Copyright (c) 2017-2019, Intel Corporation
33
#
44
# Redistribution and use in source and binary forms, with or without

mkl_fft/_pydfti.pyx

+65-4
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,48 @@ def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_f
912912
return a
913913

914914

915+
def flat_to_multi(ind, shape):
916+
nd = len(shape)
917+
m_ind = [-1] * nd
918+
j = ind
919+
for i in range(nd):
920+
si = shape[nd-1-i]
921+
q = j // si
922+
r = j - si * q
923+
m_ind[nd-1-i] = r
924+
j = q
925+
return m_ind
926+
927+
928+
def iter_complementary(x, axes, func, kwargs, result):
929+
if axes is None:
930+
return func(x, **kwargs)
931+
x_shape = x.shape
932+
nd = x.ndim
933+
r = list(range(nd))
934+
sl = [slice(None, None, None)] * nd
935+
if not isinstance(axes, tuple):
936+
axes = (axes,)
937+
for ai in axes:
938+
r[ai] = None
939+
size = 1
940+
sub_shape = []
941+
dual_ind = []
942+
for ri in r:
943+
if ri is not None:
944+
size *= x_shape[ri]
945+
sub_shape.append(x_shape[ri])
946+
dual_ind.append(ri)
947+
948+
for ind in range(size):
949+
m_ind = flat_to_multi(ind, sub_shape)
950+
for k1, k2 in zip(dual_ind, m_ind):
951+
sl[k1] = k2
952+
np.copyto(result[tuple(sl)], func(x[tuple(sl)], **kwargs))
953+
954+
return result
955+
956+
915957
def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
916958
"""Perform n-dimensional FFT over all axes"""
917959
cdef int err
@@ -988,6 +1030,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
9881030

9891031
return f_arr
9901032

1033+
9911034
def _check_shapes_for_direct(xs, shape, axes):
9921035
if len(axes) > 7: # Intel MKL supports up to 7D
9931036
return False
@@ -1006,6 +1049,14 @@ def _check_shapes_for_direct(xs, shape, axes):
10061049
return True
10071050

10081051

1052+
def _output_dtype(dt):
1053+
if dt == np.double:
1054+
return np.cdouble
1055+
if dt == np.single:
1056+
return np.csingle
1057+
return dt
1058+
1059+
10091060
def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, double fsc=1.0):
10101061
if direction not in [-1, +1]:
10111062
raise ValueError("Direction of FFT should +1 or -1")
@@ -1026,10 +1077,20 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
10261077
if _direct:
10271078
return _direct_fftnd(x, overwrite_arg=overwrite_x, direction=direction, fsc=fsc)
10281079
else:
1029-
sc = (<object> fsc)**(1/x.ndim)
1030-
return _iter_fftnd(x, s=shape, axes=axes,
1031-
overwrite_arg=overwrite_x, scale_function=lambda n: sc,
1032-
function=fft if direction == 1 else ifft)
1080+
if (shape is None):
1081+
x = np.asarray(x)
1082+
res = np.empty(x.shape, dtype=_output_dtype(x.dtype))
1083+
return iter_complementary(
1084+
x, axes,
1085+
_direct_fftnd,
1086+
{'overwrite_arg': overwrite_x, 'direction': direction, 'fsc': fsc},
1087+
res
1088+
)
1089+
else:
1090+
sc = (<object> fsc)**(1/x.ndim)
1091+
return _iter_fftnd(x, s=shape, axes=axes,
1092+
overwrite_arg=overwrite_x, scale_function=lambda n: sc,
1093+
function=fft if direction == 1 else ifft)
10331094

10341095

10351096
def fft2(x, shape=None, axes=(-2,-1), overwrite_x=False, forward_scale=1.0):

mkl_fft/tests/test_fftnd.py

+19
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,24 @@ def test_matrix4(self):
9999
assert_allclose(t_strided, t_contig, rtol=r_tol, atol=a_tol)
100100

101101

102+
def test_matrix5(self):
103+
"""fftn of strided array is same as fftn of a contiguous copy"""
104+
rs = rnd.RandomState(1234)
105+
x = rs.randn(6, 11, 12, 13)
106+
y = x[::-2, :, :, ::3]
107+
r_tol, a_tol = _get_rtol_atol(y)
108+
f = mkl_fft.fftn(y, axes=(1,2))
109+
for i0 in range(y.shape[0]):
110+
for i3 in range(y.shape[3]):
111+
assert_allclose(
112+
f[i0, :, :, i3],
113+
mkl_fft.fftn(y[i0, :, : , i3]),
114+
rtol=r_tol, atol=a_tol
115+
)
116+
117+
118+
119+
102120
class Test_Regressions(TestCase):
103121

104122
def setUp(self):
@@ -129,6 +147,7 @@ def test_rfftn_numpy(self):
129147
tr_rfft = np.transpose(mkl_fft.rfftn_numpy(x, axes=a), a)
130148
assert_allclose(rfft_tr, tr_rfft, rtol=r_tol, atol=a_tol)
131149

150+
132151
class Test_Scales(TestCase):
133152
def setUp(self):
134153
pass

0 commit comments

Comments
 (0)