@@ -912,6 +912,48 @@ def _iter_fftnd(a, s=None, axes=None, function=fft, overwrite_arg=False, scale_f
912
912
return a
913
913
914
914
915
+ def flat_to_multi (ind , shape ):
916
+ nd = len (shape )
917
+ m_ind = [- 1 ] * nd
918
+ j = ind
919
+ for i in range (nd ):
920
+ si = shape [nd - 1 - i ]
921
+ q = j // si
922
+ r = j - si * q
923
+ m_ind [nd - 1 - i ] = r
924
+ j = q
925
+ return m_ind
926
+
927
+
928
+ def iter_complementary (x , axes , func , kwargs , result ):
929
+ if axes is None :
930
+ return func (x , ** kwargs )
931
+ x_shape = x .shape
932
+ nd = x .ndim
933
+ r = list (range (nd ))
934
+ sl = [slice (None , None , None )] * nd
935
+ if not isinstance (axes , tuple ):
936
+ axes = (axes ,)
937
+ for ai in axes :
938
+ r [ai ] = None
939
+ size = 1
940
+ sub_shape = []
941
+ dual_ind = []
942
+ for ri in r :
943
+ if ri is not None :
944
+ size * = x_shape [ri ]
945
+ sub_shape .append (x_shape [ri ])
946
+ dual_ind .append (ri )
947
+
948
+ for ind in range (size ):
949
+ m_ind = flat_to_multi (ind , sub_shape )
950
+ for k1 , k2 in zip (dual_ind , m_ind ):
951
+ sl [k1 ] = k2
952
+ np .copyto (result [tuple (sl )], func (x [tuple (sl )], ** kwargs ))
953
+
954
+ return result
955
+
956
+
915
957
def _direct_fftnd (x , overwrite_arg = False , direction = + 1 , double fsc = 1.0 ):
916
958
"""Perform n-dimensional FFT over all axes"""
917
959
cdef int err
@@ -988,6 +1030,7 @@ def _direct_fftnd(x, overwrite_arg=False, direction=+1, double fsc=1.0):
988
1030
989
1031
return f_arr
990
1032
1033
+
991
1034
def _check_shapes_for_direct (xs , shape , axes ):
992
1035
if len (axes ) > 7 : # Intel MKL supports up to 7D
993
1036
return False
@@ -1006,6 +1049,14 @@ def _check_shapes_for_direct(xs, shape, axes):
1006
1049
return True
1007
1050
1008
1051
1052
+ def _output_dtype (dt ):
1053
+ if dt == np .double :
1054
+ return np .cdouble
1055
+ if dt == np .single :
1056
+ return np .csingle
1057
+ return dt
1058
+
1059
+
1009
1060
def _fftnd_impl (x , shape = None , axes = None , overwrite_x = False , direction = + 1 , double fsc = 1.0 ):
1010
1061
if direction not in [- 1 , + 1 ]:
1011
1062
raise ValueError ("Direction of FFT should +1 or -1" )
@@ -1026,10 +1077,20 @@ def _fftnd_impl(x, shape=None, axes=None, overwrite_x=False, direction=+1, doubl
1026
1077
if _direct :
1027
1078
return _direct_fftnd (x , overwrite_arg = overwrite_x , direction = direction , fsc = fsc )
1028
1079
else :
1029
- sc = (< object > fsc )** (1 / x .ndim )
1030
- return _iter_fftnd (x , s = shape , axes = axes ,
1031
- overwrite_arg = overwrite_x , scale_function = lambda n : sc ,
1032
- function = fft if direction == 1 else ifft )
1080
+ if (shape is None ):
1081
+ x = np .asarray (x )
1082
+ res = np .empty (x .shape , dtype = _output_dtype (x .dtype ))
1083
+ return iter_complementary (
1084
+ x , axes ,
1085
+ _direct_fftnd ,
1086
+ {'overwrite_arg' : overwrite_x , 'direction' : direction , 'fsc' : fsc },
1087
+ res
1088
+ )
1089
+ else :
1090
+ sc = (< object > fsc )** (1 / x .ndim )
1091
+ return _iter_fftnd (x , s = shape , axes = axes ,
1092
+ overwrite_arg = overwrite_x , scale_function = lambda n : sc ,
1093
+ function = fft if direction == 1 else ifft )
1033
1094
1034
1095
1035
1096
def fft2 (x , shape = None , axes = (- 2 ,- 1 ), overwrite_x = False , forward_scale = 1.0 ):
0 commit comments