Skip to content

Commit 76ffeca

Browse files
committed
Add test and fix impl
1 parent eceaa62 commit 76ffeca

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8167,7 +8167,6 @@ def aten_stft(
81678167
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
81688168
hop_length = n_fft // 4
81698169
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
8170-
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81718170

81728171
# Pre-process input if needed
81738172
is_signal_rank1 = len(self.shape) == 1
@@ -8203,7 +8202,7 @@ def aten_stft(
82038202
else:
82048203
onesided = 0
82058204
window = op.CastLike(window, self)
8206-
result = op.STFT(self, frame_step_const, window, frame_length_const, onesided=onesided)
8205+
result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided)
82078206
result = op.Transpose(result, perm=[0, 2, 1, 3])
82088207
# Remove batch dimension, if needed
82098208
if is_signal_rank1:

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,75 @@ def forward(self, x):
238238
)
239239
_testing.assert_onnx_program(onnx_program)
240240

241+
def test_aten_stft_1(self):
242+
class Model(torch.nn.Module):
243+
def forward(self, x):
244+
return torch.stft(x, n_fft=4, return_complex=True)
245+
246+
x = torch.randn(4, 16, dtype=torch.float32)
247+
248+
onnx_program = torch.onnx.export(
249+
Model(),
250+
(x,),
251+
dynamo=True,
252+
verbose=False,
253+
)
254+
_testing.assert_onnx_program(onnx_program)
255+
256+
def test_aten_stft_2(self):
257+
class Model(torch.nn.Module):
258+
def forward(self, x):
259+
return torch.stft(x, n_fft=4, return_complex=False)
260+
261+
x = torch.randn(4, 16, dtype=torch.float32)
262+
263+
onnx_program = torch.onnx.export(
264+
Model(),
265+
(x,),
266+
dynamo=True,
267+
verbose=False,
268+
)
269+
_testing.assert_onnx_program(onnx_program)
270+
271+
def test_aten_stft_3(self):
272+
class Model(torch.nn.Module):
273+
def forward(self, x):
274+
window = torch.ones(16, dtype=torch.float32)
275+
return torch.stft(x, n_fft=16, window=window, return_complex=False)
276+
277+
x = torch.randn(100, dtype=torch.float32)
278+
279+
onnx_program = torch.onnx.export(
280+
Model(),
281+
(x,),
282+
dynamo=True,
283+
verbose=False,
284+
)
285+
_testing.assert_onnx_program(onnx_program)
286+
287+
def test_aten_stft_4(self):
288+
class Model(torch.nn.Module):
289+
def forward(self, x):
290+
return torch.stft(
291+
x,
292+
n_fft=4,
293+
hop_length=1,
294+
win_length=4,
295+
center=True,
296+
onesided=True,
297+
return_complex=True,
298+
)
299+
300+
x = torch.randn(4, 16, dtype=torch.float32)
301+
302+
onnx_program = torch.onnx.export(
303+
Model(),
304+
(x,),
305+
dynamo=True,
306+
verbose=False,
307+
)
308+
_testing.assert_onnx_program(onnx_program)
309+
241310

242311
if __name__ == "__main__":
243312
unittest.main()

0 commit comments

Comments
 (0)