@@ -302,6 +302,75 @@ def forward(self, x):
302302 )
303303 _testing .assert_onnx_program (onnx_program )
304304
305+ def test_aten_stft_1 (self ):
306+ class Model (torch .nn .Module ):
307+ def forward (self , x ):
308+ return torch .ops .aten .stft (x , n_fft = 4 , return_complex = True )
309+
310+ x = torch .randn (4 , 16 , dtype = torch .float32 )
311+
312+ onnx_program = torch .onnx .export (
313+ Model (),
314+ (x ,),
315+ dynamo = True ,
316+ verbose = False ,
317+ )
318+ _testing .assert_onnx_program (onnx_program )
319+
320+ def test_aten_stft_2 (self ):
321+ class Model (torch .nn .Module ):
322+ def forward (self , x ):
323+ return torch .ops .aten .stft (x , n_fft = 4 , return_complex = False )
324+
325+ x = torch .randn (4 , 16 , dtype = torch .float32 )
326+
327+ onnx_program = torch .onnx .export (
328+ Model (),
329+ (x ,),
330+ dynamo = True ,
331+ verbose = False ,
332+ )
333+ _testing .assert_onnx_program (onnx_program )
334+
335+ def test_aten_stft_3 (self ):
336+ class Model (torch .nn .Module ):
337+ def forward (self , x ):
338+ window = torch .ones (16 , dtype = torch .float32 )
339+ return torch .ops .aten .stft (x , n_fft = 16 , window = window , return_complex = False )
340+
341+ x = torch .randn (100 , dtype = torch .float32 )
342+
343+ onnx_program = torch .onnx .export (
344+ Model (),
345+ (x ,),
346+ dynamo = True ,
347+ verbose = False ,
348+ )
349+ _testing .assert_onnx_program (onnx_program )
350+
351+ def test_aten_stft_4 (self ):
352+ class Model (torch .nn .Module ):
353+ def forward (self , x ):
354+ return torch .ops .aten .stft (
355+ x ,
356+ n_fft = 4 ,
357+ hop_length = 1 ,
358+ win_length = 4 ,
359+ center = True ,
360+ onesided = True ,
361+ return_complex = True ,
362+ )
363+
364+ x = torch .randn (4 , 16 , dtype = torch .float32 )
365+
366+ onnx_program = torch .onnx .export (
367+ Model (),
368+ (x ,),
369+ dynamo = True ,
370+ verbose = False ,
371+ )
372+ _testing .assert_onnx_program (onnx_program )
373+
305374
306375if __name__ == "__main__" :
307376 unittest .main ()
0 commit comments