
Commit 4738478

Revert "[torchlib] Unregister stft, var, var_mean, std, std_mean" (#1867)

This reverts commit 1eef633.

1 parent a1be5c8

File tree: 2 files changed, 141 insertions(+), 0 deletions(-)

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 133 additions & 0 deletions
@@ -8548,6 +8548,139 @@ def aten_std_mean_correction(
     return op.Sqrt(var), mean
 
 
+@torch_op("aten::stft", private=True)
+def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
+    signal_rank = Rank(self)
+    if signal_rank == 1:
+        # Add a batch dimension
+        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
+    return op.Identity(self), signal_rank
+
+
+@torch_op("aten::stft", private=True)
+def _center_window_around_zeros_if_needed(
+    window: TFloatOrBFloat16, n_fft: int
+) -> TFloatOrBFloat16:
+    # first dimension
+    n_win = op.Shape(window, start=0, end=1)
+    # Center window around zeros if needed (required by ONNX's STFT)
+    if n_win < n_fft:
+        left = (n_fft - n_win) / 2
+
+        right = n_fft - left - n_win
+        left = op.Reshape(left, op.Constant(value_ints=[1]))
+        right = op.Reshape(right, op.Constant(value_ints=[1]))
+
+        left_win = op.Expand(op.Constant(value_ints=[0]), left)
+        right_win = op.Expand(op.Constant(value_ints=[0]), right)
+        right_win = op.CastLike(right_win, window)
+        left_win = op.CastLike(left_win, window)
+        window = op.Concat(left_win, window, right_win, axis=0)
+    return window
+
+
+@torch_op("aten::stft", private=True)
+def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
+    left = (n_fft - win_length) / 2
+
+    right = n_fft - left - win_length
+    left = op.Reshape(left, op.Constant(value_ints=[1]))
+    right = op.Reshape(right, op.Constant(value_ints=[1]))
+    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
+
+    left_win = op.Expand(op.Constant(value_ints=[0]), left)
+    right_win = op.Expand(op.Constant(value_ints=[0]), right)
+    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
+    return op.Concat(left_win, window_list, right_win, axis=0)
+
+
+@torch_op("aten::stft", private=True)
+def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
+    return window
+
+
+@torch_op("aten::stft", private=True)
+def _normalize_fft_result(
+    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
+) -> TFloatOrBFloat16:
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
+    result = result / sqrt_nfft
+    return result
+
+
+@torch_op("aten::stft", private=True)
+def _aten_stft_onnx(
+    signal: TFloatOrBFloat16,
+    frame_step_const: INT64,
+    window: Union[TFloatOrBFloat16, INT64],
+    frame_length_const: INT64,
+    signal_rank: INT64,
+    onesided: int,
+) -> TFloatOrBFloat16:
+    window = op.CastLike(window, signal)
+    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
+    result = op.Transpose(result, perm=[0, 2, 1, 3])
+    # Remove batch dimension, if needed
+    if signal_rank == 1:
+        result = op.Squeeze(result, op.Constant(value_ints=[0]))
+    return result
+
+
+@torch_op("aten::stft", trace_only=True)
+def aten_stft(
+    self: TFloatOrBFloat16,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[TFloatOrBFloat16] = None,
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> TFloatOrBFloat16:
+    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
+
+    # NOTE: Regardless of the value of return_complex, we always return a real representation.
+    del return_complex
+
+    # Get STFT sizes
+    if hop_length is None:
+        # core dump
+        # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
+        hop_length = n_fft // 4
+    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
+    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+
+    # Pre-process input if needed
+    self, signal_rank = _add_batch_dimension(self)
+
+    # Get window and make sure it's the same size as `win_length` or `n_fft`
+    if window is not None and window.shape[0] is not None:
+        window = _center_window_around_zeros_if_needed(window, n_fft)
+    elif window is None:
+        if win_length is not None:
+            window = _create_window_from_win_length(win_length, n_fft)
+        else:
+            window = _create_window_from_n_fft(n_fft)
+
+    if onesided is None or onesided:
+        onesided = 1
+    else:
+        onesided = 0
+    # The batch dimension added above is removed again inside _aten_stft_onnx
+    result = _aten_stft_onnx(
+        self, frame_step_const, window, frame_length_const, signal_rank, onesided
+    )
+
+    # Normalize, if needed
+    if normalized:
+        result = _normalize_fft_result(self, result, n_fft)
+
+    return result
+
 
 @torch_op(
     (
         "aten::sub.Tensor",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 8 additions & 0 deletions
@@ -1760,6 +1760,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
     TorchLibOpInfo("slice", core_ops.aten_slice),
     TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
+    TorchLibOpInfo(
+        "ops.aten.stft",  # Custom from extra_opinfo
+        core_ops.aten_stft,
+        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
+    ).xfail(
+        dtypes=(torch.float16,),
+        reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
+    ),
     TorchLibOpInfo(
         "sum",
         core_ops.aten_sum_dim_IntList,
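The float16 xfail above can be reproduced outside the test suite; a minimal sketch, assuming a CPU build of PyTorch whose FFTs are backed by MKL:

import torch

signal = torch.randn(16000, dtype=torch.float16)
window = torch.hann_window(400, dtype=torch.float16)
try:
    torch.stft(signal, n_fft=400, window=window, return_complex=True)
except RuntimeError as err:
    # On MKL-backed CPU builds this raises:
    # "MKL FFT doesn't support tensors of type: Half"
    print(err)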
