@@ -9534,13 +9534,13 @@ def forward(self, x):
95349534 [None , 1 , 3 ], # channels
95359535 [16 , 32 ], # n_fft
95369536 [5 , 9 ], # num_frames
9537- [None , 4 , 5 ], # hop_length
9537+ [None , 5 ], # hop_length
95389538 [None , 10 , 8 ], # win_length
95399539 [None , torch .hann_window ], # window
95409540 [False , True ], # center
95419541 [False , True ], # normalized
95429542 [None , False , True ], # onesided
9543- [None , 30 , 40 ], # length
9543+ [None , "shorter" , "larger" ], # length
95449544 [False , True ], # return_complex
95459545 )
95469546 )
@@ -9551,9 +9551,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
95519551 if hop_length is None and win_length is not None :
95529552 pytest .skip ("If win_length is set then we must set hop_length and 0 < hop_length <= win_length" )
95539553
9554+ # Compute input_shape to generate test case
95549555 freq = n_fft // 2 + 1 if onesided else n_fft
95559556 input_shape = (channels , freq , num_frames ) if channels else (freq , num_frames )
95569557
9558+ # If not set,c ompute hop_length for capturing errors
9559+ if hop_length is None :
9560+ hop_length = n_fft // 4
9561+
9562+ if length == "shorter" :
9563+ length = n_fft // 2 + hop_length * (num_frames - 1 )
9564+ elif length == "larger" :
9565+ length = n_fft * 3 // 2 + hop_length * (num_frames - 1 )
9566+
95579567 class ISTFTModel (torch .nn .Module ):
95589568 def forward (self , x ):
95599569 applied_window = window (win_length ) if window and win_length else None
@@ -9573,7 +9583,7 @@ def forward(self, x):
95739583 else :
95749584 return torch .real (x )
95759585
9576- if win_length and center is False :
9586+ if ( center is False and win_length ) or ( center and win_length and length ) :
95779587 # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
95789588 with pytest .raises (RuntimeError , match = "istft\(.*\) window overlap add min: 1" ):
95799589 TorchBaseTest .run_compare_torch (
@@ -9582,7 +9592,7 @@ def forward(self, x):
95829592 backend = backend ,
95839593 compute_unit = compute_unit
95849594 )
9585- elif length is not None and return_complex is True :
9595+ elif length and return_complex :
95869596 with pytest .raises (ValueError , match = "New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`" ):
95879597 TorchBaseTest .run_compare_torch (
95889598 input_shape ,
0 commit comments