Skip to content

Commit 111f4ed

Browse files
committed
Fix aten_stft
1 parent 4738478 commit 111f4ed

File tree

1 file changed

+36
-71
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+36
-71
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 36 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -8548,42 +8548,10 @@ def aten_std_mean_correction(
85488548
return op.Sqrt(var), mean
85498549

85508550

8551-
@torch_op("aten::stft", private=True)
8552-
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
8553-
signal_rank = Rank(self)
8554-
if signal_rank == 1:
8555-
# Add a batch dimension
8556-
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
8557-
return op.Identity(self), signal_rank
8558-
8559-
8560-
@torch_op("aten::stft", private=True)
8561-
def _center_window_around_zeros_if_needed(
8562-
window: TFloatOrBFloat16, n_fft: int
8563-
) -> TFloatOrBFloat16:
8564-
# first dimension
8565-
n_win = op.Shape(window, start=0, end=1)
8566-
# Center window around zeros if needed (required by ONNX's STFT)
8567-
if n_win < n_fft:
8568-
left = (n_fft - n_win) / 2
8569-
8570-
right = n_fft - left - n_win
8571-
left = op.Reshape(left, op.Constant(value_ints=[1]))
8572-
right = op.Reshape(right, op.Constant(value_ints=[1]))
8573-
8574-
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8575-
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8576-
right_win = op.CastLike(right_win, window)
8577-
left_win = op.CastLike(left_win, window)
8578-
window = op.Concat(left_win, window, right_win, axis=0)
8579-
return window
8580-
8581-
8582-
@torch_op("aten::stft", private=True)
8583-
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
8584-
left = (n_fft - win_length) / 2
8551+
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
8552+
left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))
85858553

8586-
right = n_fft - left - win_length
8554+
right = op.Sub(op.Sub(n_fft, left), win_length)
85878555
left = op.Reshape(left, op.Constant(value_ints=[1]))
85888556
right = op.Reshape(right, op.Constant(value_ints=[1]))
85898557
win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
@@ -8594,71 +8562,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
85948562
return op.Concat(left_win, window_list, right_win, axis=0)
85958563

85968564

8597-
@torch_op("aten::stft", private=True)
8598-
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
8565+
def _create_window_from_n_fft(n_fft: int) -> TFloat:
85998566
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
86008567
window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
86018568
return window
86028569

86038570

8604-
@torch_op("aten::stft", private=True)
8605-
def _normalize_fft_result(
8606-
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
8607-
) -> TFloatOrBFloat16:
8571+
def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
86088572
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
86098573
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
8610-
result = result / sqrt_nfft
8611-
return result
8612-
8613-
8614-
@torch_op("aten::stft", private=True)
8615-
def _aten_stft_onnx(
8616-
signal: TFloatOrBFloat16,
8617-
frame_step_const: INT64,
8618-
window: Union[TFloatOrBFloat16, INT64],
8619-
frame_length_const: INT64,
8620-
signal_rank: INT64,
8621-
onesided: int,
8622-
) -> TFloatOrBFloat16:
8623-
window = op.CastLike(window, signal)
8624-
result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
8625-
result = op.Transpose(result, perm=[0, 2, 1, 3])
8626-
# Remove batch dimension, if needed
8627-
if signal_rank == 1:
8628-
result = op.Squeeze(result, op.Constant(value_ints=[0]))
8574+
result = op.Div(result, sqrt_nfft)
86298575
return result
86308576

86318577

86328578
@torch_op("aten::stft", trace_only=True)
86338579
def aten_stft(
8634-
self: TFloatOrBFloat16,
8580+
self: TFloat,
86358581
n_fft: int,
86368582
hop_length: Optional[int] = None,
86378583
win_length: Optional[int] = None,
8638-
window: Optional[TFloatOrBFloat16] = None,
8584+
window: Optional[TFloat] = None,
86398585
normalized: bool = False,
86408586
onesided: Optional[bool] = None,
86418587
return_complex: Optional[bool] = None,
8642-
) -> TFloatOrBFloat16:
8588+
) -> TFloat:
86438589
"""stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
86448590

8645-
# NOTE: regarless of the value of return_complex, we always return a real representation.
8591+
# NOTE: regardless of the value of return_complex, we always return a real representation.
86468592
del return_complex
86478593

86488594
# Get STFT sizes
86498595
if hop_length is None:
86508596
# core dump
8651-
# hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8597+
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
86528598
hop_length = n_fft // 4
86538599
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
86548600
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
86558601

86568602
# Pre-process input if needed
8657-
self, signal_rank = _add_batch_dimension(self)
8603+
is_signal_rank1 = len(self.shape) == 1
8604+
if is_signal_rank1:
8605+
# Add a batch dimension
8606+
self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))
86588607

86598608
# Get window and make sure it's the same size as `win_length` or `n_fft`
86608609
if window is not None and window.shape[0] is not None:
8661-
window = _center_window_around_zeros_if_needed(window, n_fft)
8610+
# first dimension
8611+
n_win = op.Shape(window, start=0, end=1)
8612+
# Center window around zeros if needed (required by ONNX's STFT)
8613+
if n_win < n_fft:
8614+
left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))
8615+
8616+
right = op.Sub(op.Sub(n_fft, left), n_win)
8617+
left = op.Reshape(left, op.Constant(value_ints=[1]))
8618+
right = op.Reshape(right, op.Constant(value_ints=[1]))
8619+
8620+
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8621+
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8622+
right_win = op.CastLike(right_win, window)
8623+
left_win = op.CastLike(left_win, window)
8624+
window = op.Concat(left_win, window, right_win, axis=0)
86628625
elif window is None:
86638626
if win_length is not None:
86648627
window = _create_window_from_win_length(win_length, n_fft)
@@ -8669,10 +8632,12 @@ def aten_stft(
86698632
onesided = 1
86708633
else:
86718634
onesided = 0
8672-
# remove batch dimension included
8673-
result = _aten_stft_onnx(
8674-
self, frame_step_const, window, frame_length_const, signal_rank, onesided
8675-
)
8635+
window = op.CastLike(window, self)
8636+
result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided)
8637+
result = op.Transpose(result, perm=[0, 2, 1, 3])
8638+
# Remove batch dimension, if needed
8639+
if is_signal_rank1:
8640+
result = op.Squeeze(result, op.Constant(value_ints=[0]))
86768641

86778642
# Normalize, if needed
86788643
if normalized:

0 commit comments

Comments
 (0)