@@ -8548,6 +8548,103 @@ def aten_std_mean_correction(
85488548 return op .Sqrt (var ), mean
85498549
85508550
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
    """Build a length-``n_fft`` window of ``win_length`` ones padded with zeros.

    The ones are placed in the middle: ``(n_fft - win_length) / 2`` zeros go in
    front and whatever is left over to reach ``n_fft`` goes behind.

    Args:
        win_length: Number of ones in the window.
        n_fft: Total length of the returned window.

    Returns:
        The concatenation ``[zeros, ones, zeros]`` along axis 0. It is built
        from integer constants, so callers are expected to cast it as needed.
    """
    # Split the padding into a leading part and a trailing remainder.
    total_padding = op.Sub(n_fft, win_length)
    pad_before = op.Div(total_padding, op.Constant(value_ints=[2]))
    remaining = op.Sub(n_fft, pad_before)
    pad_after = op.Sub(remaining, win_length)

    # Expand needs 1-D shape tensors, so reshape each length to shape [1].
    pad_before_shape = op.Reshape(pad_before, op.Constant(value_ints=[1]))
    pad_after_shape = op.Reshape(pad_after, op.Constant(value_ints=[1]))
    ones_shape = op.Reshape(win_length, op.Constant(value_ints=[1]))

    leading_zeros = op.Expand(op.Constant(value_ints=[0]), pad_before_shape)
    trailing_zeros = op.Expand(op.Constant(value_ints=[0]), pad_after_shape)
    ones = op.Expand(op.Constant(value_ints=[1]), ones_shape)
    return op.Concat(leading_zeros, ones, trailing_zeros, axis=0)
8563+
8564+
def _create_window_from_n_fft(n_fft: int) -> TFloat:
    """Return a 1-D window made of ``n_fft`` ones (no zero padding).

    The window is expanded from an integer constant, so callers are expected
    to cast it to the dtype they need.
    """
    # Expand takes the target shape as a 1-D tensor.
    window_shape = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), window_shape)
8569+
8570+
def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
    """Scale ``result`` by ``1 / sqrt(n_fft)``.

    Args:
        signal: Only used as the dtype reference for the scale factor.
        result: Tensor to be divided.
        n_fft: FFT size whose square root is the divisor.

    Returns:
        ``result`` divided element-wise by ``sqrt(n_fft)``.
    """
    n_fft_1d = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    # Cast before Sqrt so the root is taken in the signal's dtype.
    n_fft_cast = op.CastLike(n_fft_1d, signal)
    scale = op.Sqrt(n_fft_cast)
    return op.Div(result, scale)
8576+
8577+
@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloat,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloat] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloat:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor

    Lowers ``aten::stft`` onto the ONNX ``STFT`` operator.

    Args:
        self: Input signal. A rank-1 signal gets a temporary batch dimension
            that is removed again from the result.
        n_fft: FFT size; also the length the window is padded to.
        hop_length: Frame step. Defaults to ``n_fft // 4`` when ``None``.
        win_length: Only consulted when ``window`` is ``None``; selects a
            window of ``win_length`` ones zero-padded to ``n_fft``.
        window: Optional window. When its first dimension is statically known
            and shorter than ``n_fft`` it is zero-padded on both sides; when
            that dimension is unknown it is passed through unchanged.
        normalized: If true, the result is passed through
            ``_normalize_fft_result``.
        onesided: ``None`` and truthy values request a one-sided transform.
        return_complex: Ignored; a real representation is always returned.

    Returns:
        The ``STFT`` output with axes 1 and 2 swapped.
    """
    # NOTE: the output is always a real representation, whatever the caller asked for.
    del return_complex

    # The default hop is computed in plain Python; the original author noted
    # that the op-based Div/Constant variant core-dumped.
    frame_step = n_fft // 4 if hop_length is None else hop_length
    frame_step_1d = op.Reshape(frame_step, op.Constant(value_ints=[1]))

    # STFT is fed a batched signal, so promote a rank-1 input to a batch of one.
    added_batch_dim = len(self.shape) == 1
    if added_batch_dim:
        batched = op.Unsqueeze(self, op.Constant(value_ints=[0]))
        self = op.Identity(batched)

    if window is None:
        # No window supplied: synthesize a window of ones.
        if win_length is None:
            window = _create_window_from_n_fft(n_fft)
        else:
            window = _create_window_from_win_length(win_length, n_fft)
    elif window.shape[0] is not None:
        window_size = op.Shape(window, start=0, end=1)
        # Center the window between zeros if it is shorter than n_fft
        # (required by ONNX's STFT).
        # NOTE(review): this compares the traced Shape output with a Python
        # int -- confirm the tracer in use supports `<` on that value.
        if window_size < n_fft:
            total_pad = op.Sub(n_fft, window_size)
            pad_before = op.Div(total_pad, op.Constant(value_ints=[2]))
            pad_after = op.Sub(op.Sub(n_fft, pad_before), window_size)

            pad_before = op.Reshape(pad_before, op.Constant(value_ints=[1]))
            pad_after = op.Reshape(pad_after, op.Constant(value_ints=[1]))

            leading_zeros = op.Expand(op.Constant(value_ints=[0]), pad_before)
            trailing_zeros = op.Expand(op.Constant(value_ints=[0]), pad_after)
            # Concat needs matching dtypes, so cast the padding like the window.
            trailing_zeros = op.CastLike(trailing_zeros, window)
            leading_zeros = op.CastLike(leading_zeros, window)
            window = op.Concat(leading_zeros, window, trailing_zeros, axis=0)

    # Map the optional bool to the integer attribute; None means one-sided.
    onesided_attr = 0 if (onesided is not None and not onesided) else 1

    window = op.CastLike(window, self)
    spectrum = op.STFT(self, frame_step_1d, window, n_fft, onesided=onesided_attr)
    # Swap axes 1 and 2 of the STFT output.
    spectrum = op.Transpose(spectrum, perm=[0, 2, 1, 3])

    if added_batch_dim:
        # Drop the batch dimension that was added above.
        spectrum = op.Squeeze(spectrum, op.Constant(value_ints=[0]))

    if normalized:
        spectrum = _normalize_fft_result(self, spectrum, n_fft)

    return spectrum
8646+
8647+
85518648@torch_op (
85528649 (
85538650 "aten::sub.Tensor" ,
0 commit comments