@@ -8128,42 +8128,10 @@ def aten_std_mean_correction(
81288128 return op .Sqrt (var ), mean
81298129
81308130
8131- @torch_op ("aten::stft" , private = True )
8132- def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
8133- signal_rank = Rank (self )
8134- if signal_rank == 1 :
8135- # Add a batch dimension
8136- self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
8137- return op .Identity (self ), signal_rank
8138-
8139-
8140- @torch_op ("aten::stft" , private = True )
8141- def _center_window_around_zeros_if_needed (
8142- window : TFloatOrBFloat16 , n_fft : int
8143- ) -> TFloatOrBFloat16 :
8144- # first dimension
8145- n_win = op .Shape (window , start = 0 , end = 1 )
8146- # Center window around zeros if needed (required by ONNX's STFT)
8147- if n_win < n_fft :
8148- left = (n_fft - n_win ) / 2
8149-
8150- right = n_fft - left - n_win
8151- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8152- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8153-
8154- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8155- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8156- right_win = op .CastLike (right_win , window )
8157- left_win = op .CastLike (left_win , window )
8158- window = op .Concat (left_win , window , right_win , axis = 0 )
8159- return window
8160-
8161-
8162- @torch_op ("aten::stft" , private = True )
8163- def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8164- left = (n_fft - win_length ) / 2
8131+ def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloat :
8132+ left = op .Div (op .Sub (n_fft , win_length ), op .Constant (value_ints = [2 ]))
81658133
8166- right = n_fft - left - win_length
8134+ right = op . Sub ( op . Sub ( n_fft , left ), win_length )
81678135 left = op .Reshape (left , op .Constant (value_ints = [1 ]))
81688136 right = op .Reshape (right , op .Constant (value_ints = [1 ]))
81698137 win_length = op .Reshape (win_length , op .Constant (value_ints = [1 ]))
@@ -8174,71 +8142,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
81748142 return op .Concat (left_win , window_list , right_win , axis = 0 )
81758143
81768144
8177- @torch_op ("aten::stft" , private = True )
8178- def _create_window_from_n_fft (n_fft : int ) -> TFloatOrBFloat16 :
8145+ def _create_window_from_n_fft (n_fft : int ) -> TFloat :
81798146 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
81808147 window = op .Expand (op .Constant (value_ints = [1 ]), n_fft_tensor )
81818148 return window
81828149
81838150
8184- @torch_op ("aten::stft" , private = True )
8185- def _normalize_fft_result (
8186- signal : TFloatOrBFloat16 , result : TFloatOrBFloat16 , n_fft : int
8187- ) -> TFloatOrBFloat16 :
8151+ def _normalize_fft_result (signal : TFloat , result : TFloat , n_fft : int ) -> TFloat :
81888152 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
81898153 sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
8190- result = result / sqrt_nfft
8191- return result
8192-
8193-
8194- @torch_op ("aten::stft" , private = True )
8195- def _aten_stft_onnx (
8196- signal : TFloatOrBFloat16 ,
8197- frame_step_const : INT64 ,
8198- window : Union [TFloatOrBFloat16 , INT64 ],
8199- frame_length_const : INT64 ,
8200- signal_rank : INT64 ,
8201- onesided : int ,
8202- ) -> TFloatOrBFloat16 :
8203- window = op .CastLike (window , signal )
8204- result = op .STFT (signal , frame_step_const , window , frame_length_const , onesided = onesided )
8205- result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8206- # Remove batch dimension, if needed
8207- if signal_rank == 1 :
8208- result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
8154+ result = op .Div (result , sqrt_nfft )
82098155 return result
82108156
82118157
82128158@torch_op ("aten::stft" , trace_only = True )
82138159def aten_stft (
8214- self : TFloatOrBFloat16 ,
8160+ self : TFloat ,
82158161 n_fft : int ,
82168162 hop_length : Optional [int ] = None ,
82178163 win_length : Optional [int ] = None ,
8218- window : Optional [TFloatOrBFloat16 ] = None ,
8164+ window : Optional [TFloat ] = None ,
82198165 normalized : bool = False ,
82208166 onesided : Optional [bool ] = None ,
82218167 return_complex : Optional [bool ] = None ,
8222- ) -> TFloatOrBFloat16 :
8168+ ) -> TFloat :
82238169 """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
82248170
8225- # NOTE: regarless of the value of return_complex, we always return a real representation.
8171+ # NOTE: regardless of the value of return_complex, we always return a real representation.
82268172 del return_complex
82278173
82288174 # Get STFT sizes
82298175 if hop_length is None :
82308176 # core dump
8231- # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8177+ # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
82328178 hop_length = n_fft // 4
82338179 frame_step_const = op .Reshape (hop_length , op .Constant (value_ints = [1 ]))
82348180 frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
82358181
82368182 # Pre-process input if needed
8237- self , signal_rank = _add_batch_dimension (self )
8183+ is_signal_rank1 = len (self .shape ) == 1
8184+ if is_signal_rank1 :
8185+ # Add a batch dimension
8186+ self = op .Identity (op .Unsqueeze (self , op .Constant (value_ints = [0 ])))
82388187
82398188 # Get window and make sure it's the same size as `win_length` or `n_fft`
82408189 if window is not None and window .shape [0 ] is not None :
8241- window = _center_window_around_zeros_if_needed (window , n_fft )
8190+ # first dimension
8191+ n_win = op .Shape (window , start = 0 , end = 1 )
8192+ # Center window around zeros if needed (required by ONNX's STFT)
8193+ if n_win < n_fft :
8194+ left = op .Div (op .Sub (n_fft , n_win ), op .Constant (value_ints = [2 ]))
8195+
8196+ right = op .Sub (op .Sub (n_fft , left ), n_win )
8197+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8198+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8199+
8200+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8201+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8202+ right_win = op .CastLike (right_win , window )
8203+ left_win = op .CastLike (left_win , window )
8204+ window = op .Concat (left_win , window , right_win , axis = 0 )
82428205 elif window is None :
82438206 if win_length is not None :
82448207 window = _create_window_from_win_length (win_length , n_fft )
@@ -8249,10 +8212,12 @@ def aten_stft(
82498212 onesided = 1
82508213 else :
82518214 onesided = 0
8252- # remove batch dimension included
8253- result = _aten_stft_onnx (
8254- self , frame_step_const , window , frame_length_const , signal_rank , onesided
8255- )
8215+ window = op .CastLike (window , self )
8216+ result = op .STFT (self , frame_step_const , window , frame_length_const , onesided = onesided )
8217+ result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8218+ # Remove batch dimension, if needed
8219+ if is_signal_rank1 :
8220+ result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
82568221
82578222 # Normalize, if needed
82588223 if normalized :
0 commit comments