Skip to content

Commit 90e4a13

Browse files
committed
Fix aten_stft
1 parent 44786f4 commit 90e4a13

File tree

1 file changed

+36
-71
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+36
-71
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 36 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8128,42 +8128,10 @@ def aten_std_mean_correction(
81288128
return op.Sqrt(var), mean
81298129

81308130

8131-
@torch_op("aten::stft", private=True)
8132-
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
8133-
signal_rank = Rank(self)
8134-
if signal_rank == 1:
8135-
# Add a batch dimension
8136-
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
8137-
return op.Identity(self), signal_rank
8138-
8139-
8140-
@torch_op("aten::stft", private=True)
8141-
def _center_window_around_zeros_if_needed(
8142-
window: TFloatOrBFloat16, n_fft: int
8143-
) -> TFloatOrBFloat16:
8144-
# first dimension
8145-
n_win = op.Shape(window, start=0, end=1)
8146-
# Center window around zeros if needed (required by ONNX's STFT)
8147-
if n_win < n_fft:
8148-
left = (n_fft - n_win) / 2
8149-
8150-
right = n_fft - left - n_win
8151-
left = op.Reshape(left, op.Constant(value_ints=[1]))
8152-
right = op.Reshape(right, op.Constant(value_ints=[1]))
8153-
8154-
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8155-
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8156-
right_win = op.CastLike(right_win, window)
8157-
left_win = op.CastLike(left_win, window)
8158-
window = op.Concat(left_win, window, right_win, axis=0)
8159-
return window
8160-
8161-
8162-
@torch_op("aten::stft", private=True)
8163-
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
8164-
left = (n_fft - win_length) / 2
8131+
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
8132+
left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))
81658133

8166-
right = n_fft - left - win_length
8134+
right = op.Sub(op.Sub(n_fft, left), win_length)
81678135
left = op.Reshape(left, op.Constant(value_ints=[1]))
81688136
right = op.Reshape(right, op.Constant(value_ints=[1]))
81698137
win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
@@ -8174,71 +8142,66 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa
81748142
return op.Concat(left_win, window_list, right_win, axis=0)
81758143

81768144

8177-
@torch_op("aten::stft", private=True)
8178-
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
8145+
def _create_window_from_n_fft(n_fft: int) -> TFloat:
81798146
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81808147
window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
81818148
return window
81828149

81838150

8184-
@torch_op("aten::stft", private=True)
8185-
def _normalize_fft_result(
8186-
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
8187-
) -> TFloatOrBFloat16:
8151+
def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
81888152
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81898153
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
8190-
result = result / sqrt_nfft
8191-
return result
8192-
8193-
8194-
@torch_op("aten::stft", private=True)
8195-
def _aten_stft_onnx(
8196-
signal: TFloatOrBFloat16,
8197-
frame_step_const: INT64,
8198-
window: Union[TFloatOrBFloat16, INT64],
8199-
frame_length_const: INT64,
8200-
signal_rank: INT64,
8201-
onesided: int,
8202-
) -> TFloatOrBFloat16:
8203-
window = op.CastLike(window, signal)
8204-
result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
8205-
result = op.Transpose(result, perm=[0, 2, 1, 3])
8206-
# Remove batch dimension, if needed
8207-
if signal_rank == 1:
8208-
result = op.Squeeze(result, op.Constant(value_ints=[0]))
8154+
result = op.Div(result, sqrt_nfft)
82098155
return result
82108156

82118157

82128158
@torch_op("aten::stft", trace_only=True)
82138159
def aten_stft(
8214-
self: TFloatOrBFloat16,
8160+
self: TFloat,
82158161
n_fft: int,
82168162
hop_length: Optional[int] = None,
82178163
win_length: Optional[int] = None,
8218-
window: Optional[TFloatOrBFloat16] = None,
8164+
window: Optional[TFloat] = None,
82198165
normalized: bool = False,
82208166
onesided: Optional[bool] = None,
82218167
return_complex: Optional[bool] = None,
8222-
) -> TFloatOrBFloat16:
8168+
) -> TFloat:
82238169
"""stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
82248170

8225-
# NOTE: regarless of the value of return_complex, we always return a real representation.
8171+
# NOTE: regardless of the value of return_complex, we always return a real representation.
82268172
del return_complex
82278173

82288174
# Get STFT sizes
82298175
if hop_length is None:
82308176
# core dump
8231-
# hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8177+
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
82328178
hop_length = n_fft // 4
82338179
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
82348180
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
82358181

82368182
# Pre-process input if needed
8237-
self, signal_rank = _add_batch_dimension(self)
8183+
is_signal_rank1 = len(self.shape) == 1
8184+
if is_signal_rank1:
8185+
# Add a batch dimension
8186+
self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))
82388187

82398188
# Get window and make sure it's the same size as `win_length` or `n_fft`
82408189
if window is not None and window.shape[0] is not None:
8241-
window = _center_window_around_zeros_if_needed(window, n_fft)
8190+
# first dimension
8191+
n_win = op.Shape(window, start=0, end=1)
8192+
# Center window around zeros if needed (required by ONNX's STFT)
8193+
if n_win < n_fft:
8194+
left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))
8195+
8196+
right = op.Sub(op.Sub(n_fft, left), n_win)
8197+
left = op.Reshape(left, op.Constant(value_ints=[1]))
8198+
right = op.Reshape(right, op.Constant(value_ints=[1]))
8199+
8200+
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8201+
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8202+
right_win = op.CastLike(right_win, window)
8203+
left_win = op.CastLike(left_win, window)
8204+
window = op.Concat(left_win, window, right_win, axis=0)
82428205
elif window is None:
82438206
if win_length is not None:
82448207
window = _create_window_from_win_length(win_length, n_fft)
@@ -8249,10 +8212,12 @@ def aten_stft(
82498212
onesided = 1
82508213
else:
82518214
onesided = 0
8252-
# remove batch dimension included
8253-
result = _aten_stft_onnx(
8254-
self, frame_step_const, window, frame_length_const, signal_rank, onesided
8255-
)
8215+
window = op.CastLike(window, self)
8216+
result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided)
8217+
result = op.Transpose(result, perm=[0, 2, 1, 3])
8218+
# Remove batch dimension, if needed
8219+
if is_signal_rank1:
8220+
result = op.Squeeze(result, op.Constant(value_ints=[0]))
82568221

82578222
# Normalize, if needed
82588223
if normalized:

0 commit comments

Comments
 (0)