Skip to content

Commit c0fd969

Browse files
committed
Add test and fix impl
1 parent 90e4a13 commit c0fd969

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8177,7 +8177,6 @@ def aten_stft(
81778177
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
81788178
hop_length = n_fft // 4
81798179
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
8180-
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81818180

81828181
# Pre-process input if needed
81838182
is_signal_rank1 = len(self.shape) == 1
@@ -8213,7 +8212,7 @@ def aten_stft(
82138212
else:
82148213
onesided = 0
82158214
window = op.CastLike(window, self)
8216-
result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided)
8215+
result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided)
82178216
result = op.Transpose(result, perm=[0, 2, 1, 3])
82188217
# Remove batch dimension, if needed
82198218
if is_signal_rank1:

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,75 @@ def forward(self, x):
302302
)
303303
_testing.assert_onnx_program(onnx_program)
304304

305+
def test_aten_stft_1(self):
306+
class Model(torch.nn.Module):
307+
def forward(self, x):
308+
return torch.ops.aten.stft(x, n_fft=4, return_complex=True)
309+
310+
x = torch.randn(4, 16, dtype=torch.float32)
311+
312+
onnx_program = torch.onnx.export(
313+
Model(),
314+
(x,),
315+
dynamo=True,
316+
verbose=False,
317+
)
318+
_testing.assert_onnx_program(onnx_program)
319+
320+
def test_aten_stft_2(self):
321+
class Model(torch.nn.Module):
322+
def forward(self, x):
323+
return torch.ops.aten.stft(x, n_fft=4, return_complex=False)
324+
325+
x = torch.randn(4, 16, dtype=torch.float32)
326+
327+
onnx_program = torch.onnx.export(
328+
Model(),
329+
(x,),
330+
dynamo=True,
331+
verbose=False,
332+
)
333+
_testing.assert_onnx_program(onnx_program)
334+
335+
def test_aten_stft_3(self):
336+
class Model(torch.nn.Module):
337+
def forward(self, x):
338+
window = torch.ones(16, dtype=torch.float32)
339+
return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False)
340+
341+
x = torch.randn(100, dtype=torch.float32)
342+
343+
onnx_program = torch.onnx.export(
344+
Model(),
345+
(x,),
346+
dynamo=True,
347+
verbose=False,
348+
)
349+
_testing.assert_onnx_program(onnx_program)
350+
351+
def test_aten_stft_4(self):
352+
class Model(torch.nn.Module):
353+
def forward(self, x):
354+
return torch.ops.aten.stft(
355+
x,
356+
n_fft=4,
357+
hop_length=1,
358+
win_length=4,
359+
center=True,
360+
onesided=True,
361+
return_complex=True,
362+
)
363+
364+
x = torch.randn(4, 16, dtype=torch.float32)
365+
366+
onnx_program = torch.onnx.export(
367+
Model(),
368+
(x,),
369+
dynamo=True,
370+
verbose=False,
371+
)
372+
_testing.assert_onnx_program(onnx_program)
373+
305374

306375
if __name__ == "__main__":
307376
unittest.main()

0 commit comments

Comments
 (0)