Commit 10e541e

Implement aten.stft (#2645)
Fixed pytorch/pytorch#147052

```bash
$ python -m pytest tests/function_libs/torch_lib/ops_test.py -k ops_aten_stft
============================= test session starts =============================
platform linux -- Python 3.13.1, pytest-8.4.1, pluggy-1.6.0
Using --randomly-seed=371864411
rootdir: /home/moatom/github/onnxscript
configfile: pyproject.toml
plugins: randomly-3.16.0, xdist-3.8.0, subtests-0.14.2, cov-6.2.1, hypothesis-6.138.2
collected 2158 items / 2154 deselected / 4 selected

tests/function_libs/torch_lib/ops_test.py s..x                          [100%]

============================== warnings summary ===============================
onnxscript/converter.py:457: 429 warnings
tests/function_libs/torch_lib/ops_test.py: 15 warnings
  /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'lineno'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
    expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)

onnxscript/converter.py:457: 429 warnings
tests/function_libs/torch_lib/ops_test.py: 15 warnings
  /home/moatom/github/onnxscript/onnxscript/converter.py:457: DeprecationWarning: Expression.__init__ got an unexpected keyword argument 'col_offset'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
    expr = ast.Expression(expr, lineno=expr.lineno, col_offset=expr.col_offset)

tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_stft_cpu_float32
  /home/moatom/github/onnxscript/tests/function_libs/torch_lib/ops_test_common.py:329: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.
  To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword
    value = np.array(value.cpu())

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ===========================
SKIPPED [1] tests/function_libs/torch_lib/ops_test.py:101: Traced functions does not have a function proto
== 2 passed, 1 skipped, 2154 deselected, 1 xfailed, 891 warnings, 7 subtests passed in 4.42s ==
```
1 parent a1be5c8 commit 10e541e

File tree

3 files changed: +174 −0 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 97 additions & 0 deletions
@@ -8548,6 +8548,103 @@ def aten_std_mean_correction(

```python
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
    left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))

    right = op.Sub(op.Sub(n_fft, left), win_length)
    left = op.Reshape(left, op.Constant(value_ints=[1]))
    right = op.Reshape(right, op.Constant(value_ints=[1]))
    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))

    left_win = op.Expand(op.Constant(value_ints=[0]), left)
    right_win = op.Expand(op.Constant(value_ints=[0]), right)
    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
    return op.Concat(left_win, window_list, right_win, axis=0)


def _create_window_from_n_fft(n_fft: int) -> TFloat:
    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
    return window


def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
    result = op.Div(result, sqrt_nfft)
    return result


@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloat,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloat] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloat:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""

    # NOTE: regardless of the value of return_complex, we always return a real representation.
    del return_complex

    # Get STFT sizes
    if hop_length is None:
        # Computing the default hop length with ONNX ops causes a core dump,
        # so use Python arithmetic instead:
        # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
        hop_length = n_fft // 4
    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))

    # Pre-process input if needed
    is_signal_rank1 = len(self.shape) == 1
    if is_signal_rank1:
        # Add a batch dimension
        self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    if window is not None and window.shape[0] is not None:
        # first dimension
        n_win = op.Shape(window, start=0, end=1)
        # Center window around zeros if needed (required by ONNX's STFT)
        if n_win < n_fft:
            left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))

            right = op.Sub(op.Sub(n_fft, left), n_win)
            left = op.Reshape(left, op.Constant(value_ints=[1]))
            right = op.Reshape(right, op.Constant(value_ints=[1]))

            left_win = op.Expand(op.Constant(value_ints=[0]), left)
            right_win = op.Expand(op.Constant(value_ints=[0]), right)
            right_win = op.CastLike(right_win, window)
            left_win = op.CastLike(left_win, window)
            window = op.Concat(left_win, window, right_win, axis=0)
    elif window is None:
        if win_length is not None:
            window = _create_window_from_win_length(win_length, n_fft)
        else:
            window = _create_window_from_n_fft(n_fft)

    if onesided is None or onesided:
        onesided = 1
    else:
        onesided = 0
    window = op.CastLike(window, self)
    result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided)
    result = op.Transpose(result, perm=[0, 2, 1, 3])
    # Remove batch dimension, if needed
    if is_signal_rank1:
        result = op.Squeeze(result, op.Constant(value_ints=[0]))

    # Normalize, if needed
    if normalized:
        result = _normalize_fft_result(self, result, n_fft)

    return result
```

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 69 additions & 0 deletions
@@ -406,6 +406,75 @@ def forward(self, x):

```python
    def test_aten_stft_1(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.stft(x, n_fft=4, return_complex=True)

        x = torch.randn(4, 16, dtype=torch.float32)

        onnx_program = torch.onnx.export(
            Model(),
            (x,),
            dynamo=True,
            verbose=False,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_stft_2(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.stft(x, n_fft=4, return_complex=False)

        x = torch.randn(4, 16, dtype=torch.float32)

        onnx_program = torch.onnx.export(
            Model(),
            (x,),
            dynamo=True,
            verbose=False,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_stft_3(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                window = torch.ones(16, dtype=torch.float32)
                return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False)

        x = torch.randn(100, dtype=torch.float32)

        onnx_program = torch.onnx.export(
            Model(),
            (x,),
            dynamo=True,
            verbose=False,
        )
        _testing.assert_onnx_program(onnx_program)

    def test_aten_stft_4(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.stft(
                    x,
                    n_fft=4,
                    hop_length=1,
                    win_length=4,
                    center=True,
                    onesided=True,
                    return_complex=True,
                )

        x = torch.randn(4, 16, dtype=torch.float32)

        onnx_program = torch.onnx.export(
            Model(),
            (x,),
            dynamo=True,
            verbose=False,
        )
        _testing.assert_onnx_program(onnx_program)
```
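One additional check, hypothetical and not part of these tests, is to confirm that the export actually lowers `torch.ops.aten.stft` to an ONNX `STFT` node rather than something else. A sketch using `onnx_program.model_proto`, looking in both the main graph and any local functions since the exporter may or may not inline the torch_lib function:

```python
import torch


class Model(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.stft(x, n_fft=4, return_complex=True)


x = torch.randn(4, 16, dtype=torch.float32)
onnx_program = torch.onnx.export(Model(), (x,), dynamo=True, verbose=False)

# Gather op types from the main graph and any local functions.
nodes = list(onnx_program.model_proto.graph.node)
for function in onnx_program.model_proto.functions:
    nodes.extend(function.node)
op_types = {node.op_type for node in nodes}
assert "STFT" in op_types, sorted(op_types)
```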

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 8 additions & 0 deletions
@@ -1760,6 +1760,14 @@ def _where_input_wrangler(

```python
    TorchLibOpInfo(
        "ops.aten.stft",  # Custom from extra_opinfo
        core_ops.aten_stft,
        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
    ).xfail(
        dtypes=(torch.float16,),
        reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
    ),
```
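For context on the new table entry: the per-dtype tolerance tuple appears to be interpreted as `(rtol, atol)` by the ops test harness (an assumption here, based on how other entries read), so the float32 comparison is roughly equivalent to the sketch below, where `expected` and `actual` are hypothetical stand-ins for the eager and ONNX Runtime outputs.

```python
import torch

rtol, atol = 3.7e-5, 1.8e-4  # the float32 tolerance registered above
expected = torch.randn(3, 9, 2)                         # stand-in: eager aten.stft output
actual = expected + 1e-5 * torch.randn_like(expected)   # stand-in: ONNX Runtime output
torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```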
