@@ -406,6 +406,75 @@ def forward(self, x):
406406 onnx_program = torch .onnx .export (model , (x ,), dynamo = True , verbose = False )
407407 _testing .assert_onnx_program (onnx_program )
408408
409+ def test_aten_stft_1 (self ):
410+ class Model (torch .nn .Module ):
411+ def forward (self , x ):
412+ return torch .ops .aten .stft (x , n_fft = 4 , return_complex = True )
413+
414+ x = torch .randn (4 , 16 , dtype = torch .float32 )
415+
416+ onnx_program = torch .onnx .export (
417+ Model (),
418+ (x ,),
419+ dynamo = True ,
420+ verbose = False ,
421+ )
422+ _testing .assert_onnx_program (onnx_program )
423+
424+ def test_aten_stft_2 (self ):
425+ class Model (torch .nn .Module ):
426+ def forward (self , x ):
427+ return torch .ops .aten .stft (x , n_fft = 4 , return_complex = False )
428+
429+ x = torch .randn (4 , 16 , dtype = torch .float32 )
430+
431+ onnx_program = torch .onnx .export (
432+ Model (),
433+ (x ,),
434+ dynamo = True ,
435+ verbose = False ,
436+ )
437+ _testing .assert_onnx_program (onnx_program )
438+
439+ def test_aten_stft_3 (self ):
440+ class Model (torch .nn .Module ):
441+ def forward (self , x ):
442+ window = torch .ones (16 , dtype = torch .float32 )
443+ return torch .ops .aten .stft (x , n_fft = 16 , window = window , return_complex = False )
444+
445+ x = torch .randn (100 , dtype = torch .float32 )
446+
447+ onnx_program = torch .onnx .export (
448+ Model (),
449+ (x ,),
450+ dynamo = True ,
451+ verbose = False ,
452+ )
453+ _testing .assert_onnx_program (onnx_program )
454+
455+ def test_aten_stft_4 (self ):
456+ class Model (torch .nn .Module ):
457+ def forward (self , x ):
458+ return torch .ops .aten .stft (
459+ x ,
460+ n_fft = 4 ,
461+ hop_length = 1 ,
462+ win_length = 4 ,
463+ center = True ,
464+ onesided = True ,
465+ return_complex = True ,
466+ )
467+
468+ x = torch .randn (4 , 16 , dtype = torch .float32 )
469+
470+ onnx_program = torch .onnx .export (
471+ Model (),
472+ (x ,),
473+ dynamo = True ,
474+ verbose = False ,
475+ )
476+ _testing .assert_onnx_program (onnx_program )
477+
409478
410479if __name__ == "__main__" :
411480 unittest .main ()
0 commit comments