@@ -8548,6 +8548,139 @@ def aten_std_mean_correction(
85488548 return op .Sqrt (var ), mean
85498549
85508550
8551+ @torch_op ("aten::stft" , private = True )
8552+ def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
8553+ signal_rank = Rank (self )
8554+ if signal_rank == 1 :
8555+ # Add a batch dimension
8556+ self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
8557+ return op .Identity (self ), signal_rank
8558+
8559+
8560+ @torch_op ("aten::stft" , private = True )
8561+ def _center_window_around_zeros_if_needed (
8562+ window : TFloatOrBFloat16 , n_fft : int
8563+ ) -> TFloatOrBFloat16 :
8564+ # first dimension
8565+ n_win = op .Shape (window , start = 0 , end = 1 )
8566+ # Center window around zeros if needed (required by ONNX's STFT)
8567+ if n_win < n_fft :
8568+ left = (n_fft - n_win ) / 2
8569+
8570+ right = n_fft - left - n_win
8571+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8572+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8573+
8574+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8575+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8576+ right_win = op .CastLike (right_win , window )
8577+ left_win = op .CastLike (left_win , window )
8578+ window = op .Concat (left_win , window , right_win , axis = 0 )
8579+ return window
8580+
8581+
8582+ @torch_op ("aten::stft" , private = True )
8583+ def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8584+ left = (n_fft - win_length ) / 2
8585+
8586+ right = n_fft - left - win_length
8587+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8588+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8589+ win_length = op .Reshape (win_length , op .Constant (value_ints = [1 ]))
8590+
8591+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8592+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8593+ window_list = op .Expand (op .Constant (value_ints = [1 ]), win_length )
8594+ return op .Concat (left_win , window_list , right_win , axis = 0 )
8595+
8596+
8597+ @torch_op ("aten::stft" , private = True )
8598+ def _create_window_from_n_fft (n_fft : int ) -> TFloatOrBFloat16 :
8599+ n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8600+ window = op .Expand (op .Constant (value_ints = [1 ]), n_fft_tensor )
8601+ return window
8602+
8603+
8604+ @torch_op ("aten::stft" , private = True )
8605+ def _normalize_fft_result (
8606+ signal : TFloatOrBFloat16 , result : TFloatOrBFloat16 , n_fft : int
8607+ ) -> TFloatOrBFloat16 :
8608+ n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8609+ sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
8610+ result = result / sqrt_nfft
8611+ return result
8612+
8613+
8614+ @torch_op ("aten::stft" , private = True )
8615+ def _aten_stft_onnx (
8616+ signal : TFloatOrBFloat16 ,
8617+ frame_step_const : INT64 ,
8618+ window : Union [TFloatOrBFloat16 , INT64 ],
8619+ frame_length_const : INT64 ,
8620+ signal_rank : INT64 ,
8621+ onesided : int ,
8622+ ) -> TFloatOrBFloat16 :
8623+ window = op .CastLike (window , signal )
8624+ result = op .STFT (signal , frame_step_const , window , frame_length_const , onesided = onesided )
8625+ result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8626+ # Remove batch dimension, if needed
8627+ if signal_rank == 1 :
8628+ result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
8629+ return result
8630+
8631+
8632+ @torch_op ("aten::stft" , trace_only = True )
8633+ def aten_stft (
8634+ self : TFloatOrBFloat16 ,
8635+ n_fft : int ,
8636+ hop_length : Optional [int ] = None ,
8637+ win_length : Optional [int ] = None ,
8638+ window : Optional [TFloatOrBFloat16 ] = None ,
8639+ normalized : bool = False ,
8640+ onesided : Optional [bool ] = None ,
8641+ return_complex : Optional [bool ] = None ,
8642+ ) -> TFloatOrBFloat16 :
8643+ """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
8644+
8645+ # NOTE: regarless of the value of return_complex, we always return a real representation.
8646+ del return_complex
8647+
8648+ # Get STFT sizes
8649+ if hop_length is None :
8650+ # core dump
8651+ # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8652+ hop_length = n_fft // 4
8653+ frame_step_const = op .Reshape (hop_length , op .Constant (value_ints = [1 ]))
8654+ frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8655+
8656+ # Pre-process input if needed
8657+ self , signal_rank = _add_batch_dimension (self )
8658+
8659+ # Get window and make sure it's the same size as `win_length` or `n_fft`
8660+ if window is not None and window .shape [0 ] is not None :
8661+ window = _center_window_around_zeros_if_needed (window , n_fft )
8662+ elif window is None :
8663+ if win_length is not None :
8664+ window = _create_window_from_win_length (win_length , n_fft )
8665+ else :
8666+ window = _create_window_from_n_fft (n_fft )
8667+
8668+ if onesided is None or onesided :
8669+ onesided = 1
8670+ else :
8671+ onesided = 0
8672+ # remove batch dimension included
8673+ result = _aten_stft_onnx (
8674+ self , frame_step_const , window , frame_length_const , signal_rank , onesided
8675+ )
8676+
8677+ # Normalize, if needed
8678+ if normalized :
8679+ result = _normalize_fft_result (self , result , n_fft )
8680+
8681+ return result
8682+
8683+
85518684@torch_op (
85528685 (
85538686 "aten::sub.Tensor" ,
0 commit comments