Skip to content

Commit 29ba6b9

Browse files
committed
Add test and fix impl
1 parent 111f4ed commit 29ba6b9

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8597,7 +8597,6 @@ def aten_stft(
85978597
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
85988598
hop_length = n_fft // 4
85998599
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
8600-
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
86018600

86028601
# Pre-process input if needed
86038602
is_signal_rank1 = len(self.shape) == 1
@@ -8633,7 +8632,7 @@ def aten_stft(
86338632
else:
86348633
onesided = 0
86358634
window = op.CastLike(window, self)
8636-
result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided)
8635+
result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided)
86378636
result = op.Transpose(result, perm=[0, 2, 1, 3])
86388637
# Remove batch dimension, if needed
86398638
if is_signal_rank1:

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,75 @@ def forward(self, x):
406406
onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False)
407407
_testing.assert_onnx_program(onnx_program)
408408

409+
def test_aten_stft_1(self):
410+
class Model(torch.nn.Module):
411+
def forward(self, x):
412+
return torch.ops.aten.stft(x, n_fft=4, return_complex=True)
413+
414+
x = torch.randn(4, 16, dtype=torch.float32)
415+
416+
onnx_program = torch.onnx.export(
417+
Model(),
418+
(x,),
419+
dynamo=True,
420+
verbose=False,
421+
)
422+
_testing.assert_onnx_program(onnx_program)
423+
424+
def test_aten_stft_2(self):
425+
class Model(torch.nn.Module):
426+
def forward(self, x):
427+
return torch.ops.aten.stft(x, n_fft=4, return_complex=False)
428+
429+
x = torch.randn(4, 16, dtype=torch.float32)
430+
431+
onnx_program = torch.onnx.export(
432+
Model(),
433+
(x,),
434+
dynamo=True,
435+
verbose=False,
436+
)
437+
_testing.assert_onnx_program(onnx_program)
438+
439+
def test_aten_stft_3(self):
440+
class Model(torch.nn.Module):
441+
def forward(self, x):
442+
window = torch.ones(16, dtype=torch.float32)
443+
return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False)
444+
445+
x = torch.randn(100, dtype=torch.float32)
446+
447+
onnx_program = torch.onnx.export(
448+
Model(),
449+
(x,),
450+
dynamo=True,
451+
verbose=False,
452+
)
453+
_testing.assert_onnx_program(onnx_program)
454+
455+
def test_aten_stft_4(self):
456+
class Model(torch.nn.Module):
457+
def forward(self, x):
458+
return torch.ops.aten.stft(
459+
x,
460+
n_fft=4,
461+
hop_length=1,
462+
win_length=4,
463+
center=True,
464+
onesided=True,
465+
return_complex=True,
466+
)
467+
468+
x = torch.randn(4, 16, dtype=torch.float32)
469+
470+
onnx_program = torch.onnx.export(
471+
Model(),
472+
(x,),
473+
dynamo=True,
474+
verbose=False,
475+
)
476+
_testing.assert_onnx_program(onnx_program)
477+
409478

410479
if __name__ == "__main__":
411480
unittest.main()

0 commit comments

Comments
 (0)