@@ -8548,42 +8548,10 @@ def aten_std_mean_correction(
85488548 return op .Sqrt (var ), mean
85498549
85508550
8551- @torch_op ("aten::stft" , private = True )
8552- def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
8553- signal_rank = Rank (self )
8554- if signal_rank == 1 :
8555- # Add a batch dimension
8556- self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
8557- return op .Identity (self ), signal_rank
8558-
8559-
8560- @torch_op ("aten::stft" , private = True )
8561- def _center_window_around_zeros_if_needed (
8562- window : TFloatOrBFloat16 , n_fft : int
8563- ) -> TFloatOrBFloat16 :
8564- # first dimension
8565- n_win = op .Shape (window , start = 0 , end = 1 )
8566- # Center window around zeros if needed (required by ONNX's STFT)
8567- if n_win < n_fft :
8568- left = (n_fft - n_win ) / 2
8569-
8570- right = n_fft - left - n_win
8571- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8572- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8573-
8574- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8575- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8576- right_win = op .CastLike (right_win , window )
8577- left_win = op .CastLike (left_win , window )
8578- window = op .Concat (left_win , window , right_win , axis = 0 )
8579- return window
8580-
8581-
8582- @torch_op ("aten::stft" , private = True )
8583- def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8584- left = (n_fft - win_length ) / 2
8551+ def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloat :
8552+ left = op .Div (op .Sub (n_fft , win_length ), op .Constant (value_ints = [2 ]))
85858553
8586- right = n_fft - left - win_length
8554+ right = op . Sub ( op . Sub ( n_fft , left ), win_length )
85878555 left = op .Reshape (left , op .Constant (value_ints = [1 ]))
85888556 right = op .Reshape (right , op .Constant (value_ints = [1 ]))
85898557 win_length = op .Reshape (win_length , op .Constant (value_ints = [1 ]))
@@ -8594,71 +8562,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
85948562 return op .Concat (left_win , window_list , right_win , axis = 0 )
85958563
85968564
8597- @torch_op ("aten::stft" , private = True )
8598- def _create_window_from_n_fft (n_fft : int ) -> TFloatOrBFloat16 :
8565+ def _create_window_from_n_fft (n_fft : int ) -> TFloat :
85998566 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
86008567 window = op .Expand (op .Constant (value_ints = [1 ]), n_fft_tensor )
86018568 return window
86028569
86038570
8604- @torch_op ("aten::stft" , private = True )
8605- def _normalize_fft_result (
8606- signal : TFloatOrBFloat16 , result : TFloatOrBFloat16 , n_fft : int
8607- ) -> TFloatOrBFloat16 :
8571+ def _normalize_fft_result (signal : TFloat , result : TFloat , n_fft : int ) -> TFloat :
86088572 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
86098573 sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
8610- result = result / sqrt_nfft
8611- return result
8612-
8613-
8614- @torch_op ("aten::stft" , private = True )
8615- def _aten_stft_onnx (
8616- signal : TFloatOrBFloat16 ,
8617- frame_step_const : INT64 ,
8618- window : Union [TFloatOrBFloat16 , INT64 ],
8619- frame_length_const : INT64 ,
8620- signal_rank : INT64 ,
8621- onesided : int ,
8622- ) -> TFloatOrBFloat16 :
8623- window = op .CastLike (window , signal )
8624- result = op .STFT (signal , frame_step_const , window , frame_length_const , onesided = onesided )
8625- result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8626- # Remove batch dimension, if needed
8627- if signal_rank == 1 :
8628- result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
8574+ result = op .Div (result , sqrt_nfft )
86298575 return result
86308576
86318577
86328578@torch_op ("aten::stft" , trace_only = True )
86338579def aten_stft (
8634- self : TFloatOrBFloat16 ,
8580+ self : TFloat ,
86358581 n_fft : int ,
86368582 hop_length : Optional [int ] = None ,
86378583 win_length : Optional [int ] = None ,
8638- window : Optional [TFloatOrBFloat16 ] = None ,
8584+ window : Optional [TFloat ] = None ,
86398585 normalized : bool = False ,
86408586 onesided : Optional [bool ] = None ,
86418587 return_complex : Optional [bool ] = None ,
8642- ) -> TFloatOrBFloat16 :
8588+ ) -> TFloat :
86438589 """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
86448590
8645- # NOTE: regarless of the value of return_complex, we always return a real representation.
8591+ # NOTE: regardless of the value of return_complex, we always return a real representation.
86468592 del return_complex
86478593
86488594 # Get STFT sizes
86498595 if hop_length is None :
86508596 # core dump
8651- # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8597+ # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
86528598 hop_length = n_fft // 4
86538599 frame_step_const = op .Reshape (hop_length , op .Constant (value_ints = [1 ]))
86548600 frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
86558601
86568602 # Pre-process input if needed
8657- self , signal_rank = _add_batch_dimension (self )
8603+ is_signal_rank1 = len (self .shape ) == 1
8604+ if is_signal_rank1 :
8605+ # Add a batch dimension
8606+ self = op .Identity (op .Unsqueeze (self , op .Constant (value_ints = [0 ])))
86588607
86598608 # Get window and make sure it's the same size as `win_length` or `n_fft`
86608609 if window is not None and window .shape [0 ] is not None :
8661- window = _center_window_around_zeros_if_needed (window , n_fft )
8610+ # first dimension
8611+ n_win = op .Shape (window , start = 0 , end = 1 )
8612+ # Center window around zeros if needed (required by ONNX's STFT)
8613+ if n_win < n_fft :
8614+ left = op .Div (op .Sub (n_fft , n_win ), op .Constant (value_ints = [2 ]))
8615+
8616+ right = op .Sub (op .Sub (n_fft , left ), n_win )
8617+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8618+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8619+
8620+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8621+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8622+ right_win = op .CastLike (right_win , window )
8623+ left_win = op .CastLike (left_win , window )
8624+ window = op .Concat (left_win , window , right_win , axis = 0 )
86628625 elif window is None :
86638626 if win_length is not None :
86648627 window = _create_window_from_win_length (win_length , n_fft )
@@ -8669,10 +8632,12 @@ def aten_stft(
86698632 onesided = 1
86708633 else :
86718634 onesided = 0
8672- # remove batch dimension included
8673- result = _aten_stft_onnx (
8674- self , frame_step_const , window , frame_length_const , signal_rank , onesided
8675- )
8635+ window = op .CastLike (window , self )
8636+ result = op .STFT (self , frame_step_const , window , frame_length_const , onesided = onesided )
8637+ result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8638+ # Remove batch dimension, if needed
8639+ if is_signal_rank1 :
8640+ result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
86768641
86778642 # Normalize, if needed
86788643 if normalized :
0 commit comments