Add aten.stft.center and decomposition #3880
base: main
Conversation
This reverts commit ae145aa.
CI errors (undefined symbols referenced by …)
Might need to add TOSA xfails if you haven't done so already. IIRC TOSA got added to the CI in your last sync.
A general comment before reviewing further:
Is it absolutely necessary to use a loop? The only time I would ever consider using a loop is if there is a necessary loop-carried dependency, but I don't think that is the case here. Even if there is a nice algorithm for rfft, I don't think decomposing to many rffts in a loop would be more efficient than converting this to, say, a convolution with something like window*exp (if that is possible with the configurations you are trying to support).
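To illustrate the reviewer's point, here is a hypothetical pure-Python sketch (not code from this PR; the toy signal, window, and sizes are made up). It computes the same STFT coefficients two ways: once with a per-frame DFT, and once as dot products of each signal slice against fixed kernels of the form `window[t] * exp(-2j*pi*k*t/n_fft)`. Since no frame's result feeds into another frame, there is no loop-carried dependency, which is what makes a convolution-style lowering possible.

```python
import cmath

# Hypothetical toy parameters, not taken from the PR.
n_fft, hop = 4, 2
signal = [0.5, 1.0, -0.25, 0.75, 0.0, -1.0, 0.5, 0.25]
window = [0.0, 0.5, 1.0, 0.5]

n_frames = (len(signal) - n_fft) // hop + 1

# Loop form: one DFT per windowed frame.
def dft(frame):
    n = len(frame)
    return [sum(frame[t] * cmath.exp(-2j * cmath.pi * k * t / n)
                for t in range(n)) for k in range(n)]

loop_result = [dft([signal[f * hop + t] * window[t] for t in range(n_fft)])
               for f in range(n_frames)]

# Convolution form: the same coefficients are dot products of each signal
# slice with fixed kernels window[t] * exp(-2j*pi*k*t/n_fft). No value
# computed for one frame feeds into another frame.
kernels = [[window[t] * cmath.exp(-2j * cmath.pi * k * t / n_fft)
            for t in range(n_fft)] for k in range(n_fft)]
conv_result = [[sum(signal[f * hop + t] * kernels[k][t]
                    for t in range(n_fft))
                for k in range(n_fft)] for f in range(n_frames)]
```

Both forms agree bin-for-bin; the convolution form is just a batched matrix product over independent slices.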
// init_freq_tensor = aten.empty.memory_format([batch_dim?, n_freqs,
//                                              n_frames],
//                                             self.dtype, None, None, None, None)
// final_freq_tensor = prim.loop
//   n_frames, %true, init(init_freq_tensor)
//   {
//     ^bb0(frame, freq_tensor):
//       begin = frame * hop_length
//       end = begin + n_fft
//       narrow_length = min(end, signal_len) - begin
//       missing = n_fft - narrow_length
//       sliced = torch.narrow(self, axis_signal, begin, narrow_length) :
//           !torch.vtensor<[batch_dim?,?],f32>
//       padded_sliced = aten.pad(sliced, [0, missing], "constant", 0.0) :
//           !torch.vtensor<[batch_dim?,?],f32>
//       padded_sliced = tensor_static_info_cast(padded_sliced) :
//           !torch.vtensor<[batch_dim?,n_fft],f32>
//       weighted = aten.mul.Tensor(padded_sliced, window) :
//           !torch.vtensor<[batch_dim?,n_fft],f32>
//       f = onesidedBool ? aten.fft_rfft : aten.fft_fft
//       freq_slice_sq = f(weighted, None, axis_signal) :
//           !torch.vtensor<[batch_dim?,n_freqs],f32>
//       freq_slice = aten.unsqueeze(freq_slice_sq, axis_frames) :
//           !torch.vtensor<[batch_dim?,n_freqs,1],f32>
//       new_freq_tensor = aten.slice_scatter(
//           freq_tensor, freq_slice,
//           dim=axis_frames, start=frame,
//           end=None, step=1)
//       torch.prim.Loop.condition %true, iter(%new_freq_tensor)
//   }
// return final_freq_tensor
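As a plain-Python rendering of what the pseudo-IR above does per iteration (a hypothetical sketch; the signal, window, and frame count are made-up toy values, and `aten.fft_rfft` is stood in for by a direct DFT over the first `n_freqs` bins):

```python
import cmath

# Hypothetical toy inputs; names mirror the pseudo-IR above.
signal = [1.0, 2.0, 3.0, 4.0, 5.0]
n_fft, hop_length = 4, 2
window = [1.0] * n_fft
onesided = True
n_freqs = n_fft // 2 + 1 if onesided else n_fft
n_frames = 2  # chosen so the last frame exercises the zero-padding path

# init_freq_tensor: [n_freqs, n_frames], filled in frame by frame.
freq_tensor = [[0j] * n_frames for _ in range(n_freqs)]

for frame in range(n_frames):
    begin = frame * hop_length
    end = begin + n_fft
    narrow_length = min(end, len(signal)) - begin
    missing = n_fft - narrow_length
    sliced = signal[begin:begin + narrow_length]              # torch.narrow
    padded = sliced + [0.0] * missing                         # aten.pad
    weighted = [padded[t] * window[t] for t in range(n_fft)]  # aten.mul.Tensor
    # aten.fft_rfft stand-in, scattered into the frame column
    # (aten.slice_scatter along axis_frames).
    for k in range(n_freqs):
        freq_tensor[k][frame] = sum(
            weighted[t] * cmath.exp(-2j * cmath.pi * k * t / n_fft)
            for t in range(n_fft))
```

The second frame runs past the end of the signal, so `missing == 1` and the slice is zero-padded back to `n_fft` before windowing, matching the `aten.pad` step in the pseudo-IR.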
Pseudo-IR isn't very helpful as a comment, since you include lit tests for this decomposition.
if (isa<Torch::NoneType>(hopLength.getType())) {
  hopLength = rewriter.create<AtenFloordivIntOp>(
      loc, n_fft,
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
nit: There is a builder for Torch::ConstantIntOp which allows passing an int directly, which is a bit easier to read.

Suggested change:
-      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
+      rewriter.create<ConstantIntOp>(loc, 4));
Value center = op.getCenter();
bool centerBool;
// TODO: add support for non-constant center and center=True
if (!matchPattern(center, m_TorchConstantBool(&centerBool)))
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: non-constant center");
if (centerBool)
  return rewriter.notifyMatchFailure(op, "Unsupported: center=True");
Value normalized = op.getNormalized();
bool normalizedBool;
// TODO: add support for non-constant normalized and normalized=True
if (!matchPattern(normalized, m_TorchConstantBool(&normalizedBool)))
  return rewriter.notifyMatchFailure(
      op, "Unsupported: non-constant normalized");
if (normalizedBool)
  return rewriter.notifyMatchFailure(op, "Unsupported: normalized=True");
bool onesidedBool;
// Default: True for real input and window, False otherwise.
// TODO: add support for non-constant onesided
if (isa<Torch::NoneType>(op.getOnesided().getType())) {
  Type dtype = selfType.getDtype();
  onesidedBool = !isa<mlir::ComplexType>(dtype);
} else if (!matchPattern(op.getOnesided(),
                         m_TorchConstantBool(&onesidedBool)))
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: non-constant onesided");
Value returnComplex = op.getReturnComplex();
bool returnComplexBool;
// TODO: add support for non-constant return_complex and return_complex=True
if (!matchPattern(returnComplex, m_TorchConstantBool(&returnComplexBool)))
  return rewriter.notifyMatchFailure(
      op, "Unsupported: non-constant return_complex");
if (!returnComplexBool)
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: return_complex=False");
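The `onesided` default in the hunk above (True for real inputs) rests on a property of the DFT worth making concrete: for a real signal the spectrum is conjugate-symmetric, `X[(n - k) % n] == conj(X[k])`, so only the first `n // 2 + 1` bins carry information. A small self-contained check (hypothetical toy signal, not from the PR):

```python
import cmath

# Hypothetical real signal; demonstrates why onesided defaults to True
# for real dtypes: bins beyond n//2 are redundant by conjugate symmetry.
x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
n = len(x)
spectrum = [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n)
                for t in range(n)) for k in range(n)]
n_freqs = n // 2 + 1  # bins kept when onesided=True (cf. aten.fft_rfft)
symmetric = all(abs(spectrum[(n - k) % n] - spectrum[k].conjugate()) < 1e-9
                for k in range(n))
```

For complex inputs this symmetry does not hold, which is why the default flips to False when the dtype is complex.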
Move all of these match failures before the generation of runtime assert ops in the `if (hasWindow)` block.
The choice to work with aten.stft.center instead of aten.stft is because the latter doesn't match the signature that gets exposed (see https://pytorch.org/docs/stable/generated/torch.stft.html).