@@ -8118,42 +8118,10 @@ def aten_std_mean_correction(
81188118 return op .Sqrt (var ), mean
81198119
81208120
8121- @torch_op ("aten::stft" , private = True )
8122- def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
8123- signal_rank = Rank (self )
8124- if signal_rank == 1 :
8125- # Add a batch dimension
8126- self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
8127- return op .Identity (self ), signal_rank
8128-
8129-
8130- @torch_op ("aten::stft" , private = True )
8131- def _center_window_around_zeros_if_needed (
8132- window : TFloatOrBFloat16 , n_fft : int
8133- ) -> TFloatOrBFloat16 :
8134- # first dimension
8135- n_win = op .Shape (window , start = 0 , end = 1 )
8136- # Center window around zeros if needed (required by ONNX's STFT)
8137- if n_win < n_fft :
8138- left = (n_fft - n_win ) / 2
8139-
8140- right = n_fft - left - n_win
8141- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8142- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8143-
8144- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8145- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8146- right_win = op .CastLike (right_win , window )
8147- left_win = op .CastLike (left_win , window )
8148- window = op .Concat (left_win , window , right_win , axis = 0 )
8149- return window
8150-
8151-
8152- @torch_op ("aten::stft" , private = True )
8153- def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8154- left = (n_fft - win_length ) / 2
8121+ def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloat :
8122+ left = op .Div (op .Sub (n_fft , win_length ), op .Constant (value_ints = [2 ]))
81558123
8156- right = n_fft - left - win_length
8124+ right = op . Sub ( op . Sub ( n_fft , left ), win_length )
81578125 left = op .Reshape (left , op .Constant (value_ints = [1 ]))
81588126 right = op .Reshape (right , op .Constant (value_ints = [1 ]))
81598127 win_length = op .Reshape (win_length , op .Constant (value_ints = [1 ]))
@@ -8164,71 +8132,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
81648132 return op .Concat (left_win , window_list , right_win , axis = 0 )
81658133
81668134
def _create_window_from_n_fft(n_fft: int) -> TFloat:
    """Build a rectangular (all-ones) window with `n_fft` elements.

    Args:
        n_fft: Number of elements the window should contain.

    Returns:
        A 1-D tensor obtained by expanding the integer constant ``[1]`` to
        shape ``[n_fft]``. The values come from ``value_ints``, so no cast to
        a floating-point type happens inside this helper.
    """
    # Expand takes its target shape as a 1-D tensor, so wrap n_fft as [n_fft].
    target_shape = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), target_shape)
81728139
81738140
def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
    """Divide an FFT result by ``sqrt(n_fft)``.

    Args:
        signal: Tensor used only as the type reference for the cast of
            `n_fft`; its values are not read here.
        result: Tensor to be scaled.
        n_fft: FFT size whose square root is the divisor.

    Returns:
        ``result / sqrt(n_fft)``, where the divisor is a one-element tensor
        that broadcasts against `result`.
    """
    fft_size = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    # Cast to the signal's element type before Sqrt so the divisor matches it.
    divisor = op.Sqrt(op.CastLike(fft_size, signal))
    return op.Div(result, divisor)
82008146
82018147
82028148@torch_op ("aten::stft" , trace_only = True )
82038149def aten_stft (
8204- self : TFloatOrBFloat16 ,
8150+ self : TFloat ,
82058151 n_fft : int ,
82068152 hop_length : Optional [int ] = None ,
82078153 win_length : Optional [int ] = None ,
8208- window : Optional [TFloatOrBFloat16 ] = None ,
8154+ window : Optional [TFloat ] = None ,
82098155 normalized : bool = False ,
82108156 onesided : Optional [bool ] = None ,
82118157 return_complex : Optional [bool ] = None ,
8212- ) -> TFloatOrBFloat16 :
8158+ ) -> TFloat :
82138159 """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
82148160
8215- # NOTE: regarless of the value of return_complex, we always return a real representation.
8161+ # NOTE: regardless of the value of return_complex, we always return a real representation.
82168162 del return_complex
82178163
82188164 # Get STFT sizes
82198165 if hop_length is None :
82208166 # core dump
8221- # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8167+ # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
82228168 hop_length = n_fft // 4
82238169 frame_step_const = op .Reshape (hop_length , op .Constant (value_ints = [1 ]))
82248170 frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
82258171
82268172 # Pre-process input if needed
8227- self , signal_rank = _add_batch_dimension (self )
8173+ is_signal_rank1 = len (self .shape ) == 1
8174+ if is_signal_rank1 :
8175+ # Add a batch dimension
8176+ self = op .Identity (op .Unsqueeze (self , op .Constant (value_ints = [0 ])))
82288177
82298178 # Get window and make sure it's the same size as `win_length` or `n_fft`
82308179 if window is not None and window .shape [0 ] is not None :
8231- window = _center_window_around_zeros_if_needed (window , n_fft )
8180+ # first dimension
8181+ n_win = op .Shape (window , start = 0 , end = 1 )
8182+ # Center window around zeros if needed (required by ONNX's STFT)
8183+ if n_win < n_fft :
8184+ left = op .Div (op .Sub (n_fft , n_win ), op .Constant (value_ints = [2 ]))
8185+
8186+ right = op .Sub (op .Sub (n_fft , left ), n_win )
8187+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8188+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8189+
8190+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8191+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8192+ right_win = op .CastLike (right_win , window )
8193+ left_win = op .CastLike (left_win , window )
8194+ window = op .Concat (left_win , window , right_win , axis = 0 )
82328195 elif window is None :
82338196 if win_length is not None :
82348197 window = _create_window_from_win_length (win_length , n_fft )
@@ -8239,10 +8202,12 @@ def aten_stft(
82398202 onesided = 1
82408203 else :
82418204 onesided = 0
8242- # remove batch dimension included
8243- result = _aten_stft_onnx (
8244- self , frame_step_const , window , frame_length_const , signal_rank , onesided
8245- )
8205+ window = op .CastLike (window , self )
8206+ result = op .STFT (self , frame_step_const , window , frame_length_const , onesided = onesided )
8207+ result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8208+ # Remove batch dimension, if needed
8209+ if is_signal_rank1 :
8210+ result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
82468211
82478212 # Normalize, if needed
82488213 if normalized :
0 commit comments