
Add aten.stft.center and decomposition #3880

Open · wants to merge 26 commits into base: main
Conversation

@giacs-epic (Contributor) commented Nov 18, 2024

This PR works with aten.stft.center rather than aten.stft because the latter's signature doesn't match the one exposed by torch.stft (see https://pytorch.org/docs/stable/generated/torch.stft.html).
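For reference, the two ATen op signatures differ roughly as follows. This is paraphrased from PyTorch's native_functions.yaml from memory; treat the exact parameter order and defaults as an assumption, not a quote:

```
aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None,
           Tensor? window=None, bool normalized=False, bool? onesided=None,
           bool? return_complex=None) -> Tensor

aten::stft.center(Tensor self, int n_fft, int? hop_length=None,
                  int? win_length=None, Tensor? window=None, bool center=True,
                  str pad_mode="reflect", bool normalized=False,
                  bool? onesided=None, bool? return_complex=None) -> Tensor
```

The `.center` overload carries the `center` and `pad_mode` arguments that the Python-level torch.stft signature exposes, which is why it is the better match for a decomposition.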

@giacs-epic giacs-epic marked this pull request as ready for review December 2, 2024 13:50
@giacs-epic (Contributor, Author):

CI errors (undefined symbols referenced by LazyNativeFunctions.cpp) are unrelated to the PR content.

@zjgarvey (Collaborator):

You might need to add TOSA xfails if you haven't done so already. IIRC TOSA was added to the CI in your last sync.

@zjgarvey (Collaborator) left a comment:

A general comment before reviewing further:

Is it absolutely necessary to use a loop? The only time I would consider using a loop is when there is a necessary loop-carried dependency, and I don't think that is the case here.

Even if there is a nice algorithm for rfft, I don't think decomposing to many rffts in a loop would be more efficient than converting this to, say, a convolution with something like window*exp (if that is possible with the configurations you are trying to support).
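To make the suggestion concrete: each STFT output bin is a windowed DFT of one frame, which is algebraically the same as correlating the signal with a kernel `window[m] * exp(-2j*pi*k*m/n_fft)` and sampling the result at stride `hop` — i.e. one output of a strided convolution with that kernel. A minimal pure-Python sketch (toy signal and parameters are mine, not from the PR):

```python
import math
import cmath

# Hypothetical toy signal and STFT parameters, for illustration only.
signal = [0.1 * i - 0.3 * (i % 5) for i in range(32)]
n_fft, hop = 8, 4
window = [0.5 - 0.5 * math.cos(2 * math.pi * m / n_fft) for m in range(n_fft)]

def stft_bin_via_loop(frame_idx, k):
    """Windowed DFT of one frame: what the loop-based decomposition computes."""
    start = frame_idx * hop
    frame = signal[start:start + n_fft]
    return sum(window[m] * frame[m] * cmath.exp(-2j * cmath.pi * k * m / n_fft)
               for m in range(n_fft))

def stft_bin_via_correlation(frame_idx, k):
    """Same bin as a correlation with the kernel window*exp, sampled at
    stride `hop` -- one output sample of a strided 'convolution'."""
    kernel = [window[m] * cmath.exp(-2j * cmath.pi * k * m / n_fft)
              for m in range(n_fft)]
    start = frame_idx * hop
    return sum(signal[start + m] * kernel[m] for m in range(n_fft))
```

Both formulations agree to floating-point precision, which is what makes the convolution lowering a candidate whenever the backend supports complex (or split real/imaginary) convolution kernels.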

Comment on lines +6253 to +6284
// init_freq_tensor = aten.empty.memory_format([batch_dim?, n_freqs,
// n_frames],
// self.dtype, None, None, None, None)
// final_freq_tensor = prim.loop
// n_frames, %true, init(init_freq_tensor)
// {
// ^bb0(frame, freq_tensor):
// begin = frame * hop_length
// end = begin + n_fft
// narrow_length = min(end, signal_len) - begin
// missing = n_fft - narrow_length
// sliced = torch.narrow(self, axis_signal, begin, narrow_length) :
// !torch.vtensor<[batch_dim?,?],f32>
// padded_sliced = aten.pad(sliced, [0, missing], "constant", 0.0) :
// !torch.vtensor<[batch_dim?,?],f32>
// padded_sliced = tensor_static_info_cast(padded_sliced) :
// !torch.vtensor<[batch_dim?,n_fft],f32>
// weighted = aten.mul.Tensor(padded_sliced, window) :
// !torch.vtensor<[batch_dim?,n_fft],f32>
// f = onesidedBool ? aten.fft_rfft : aten.fft_fft
// freq_slice_sq = f(weighted, None, axis_signal) :
// !torch.vtensor<[batch_dim?,n_freqs],f32>
// freq_slice = aten.unsqueeze(freq_slice_sq, axis_frames) :
// !torch.vtensor<[batch_dim?,n_freqs, 1],f32>
// new_freq_tensor = aten.slice_scatter(
// freq_tensor, freq_slice,
// dim=axis_frames, start=frame,
// end=None, step=1
// )
// torch.prim.Loop.condition %true, iter(%new_freq_tensor)
// }
// return final_freq_tensor
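A pure-Python mirror of the pseudo-IR above may be easier to follow than the IR comment itself. This is an illustrative sketch only — the helper names are mine, the naive DFT stands in for aten.fft_rfft, and center=False framing is assumed:

```python
import cmath

def naive_rfft(x):
    """Naive one-sided DFT, standing in for aten.fft_rfft."""
    n = len(x)
    return [sum(x[m] * cmath.exp(-2j * cmath.pi * k * m / n) for m in range(n))
            for k in range(n // 2 + 1)]

def stft_loop(signal, window, n_fft, hop_length):
    """Mirrors the prim.Loop body: narrow, zero-pad, window, FFT, scatter."""
    signal_len = len(signal)
    n_frames = 1 + (signal_len - n_fft) // hop_length
    n_freqs = n_fft // 2 + 1
    # init_freq_tensor, shaped [n_freqs, n_frames]
    freq = [[0j] * n_frames for _ in range(n_freqs)]
    for frame in range(n_frames):
        begin = frame * hop_length
        end = begin + n_fft
        narrow_length = min(end, signal_len) - begin
        sliced = signal[begin:begin + narrow_length]              # torch.narrow
        padded = sliced + [0.0] * (n_fft - narrow_length)         # aten.pad
        weighted = [padded[m] * window[m] for m in range(n_fft)]  # aten.mul.Tensor
        for k, v in enumerate(naive_rfft(weighted)):              # slice_scatter
            freq[k][frame] = v
    return freq
```

Note that every iteration writes a disjoint column of the output, which is exactly why the loop carries no true data dependency between iterations.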
@zjgarvey (Collaborator):
Pseudo-IR isn't very helpful as a comment, since you include lit tests for this decomposition.

if (isa<Torch::NoneType>(hopLength.getType())) {
  hopLength = rewriter.create<AtenFloordivIntOp>(
      loc, n_fft,
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
@zjgarvey (Collaborator):
nit:

There is a builder for Torch::ConstantIntOp which allows passing an int directly, which is a bit easier to read.

Suggested change:
-      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
+      rewriter.create<ConstantIntOp>(loc, 4));

Comment on lines +6162 to +6199
Value center = op.getCenter();
bool centerBool;
// TODO: add support for non-constant center and center=True
if (!matchPattern(center, m_TorchConstantBool(&centerBool)))
return rewriter.notifyMatchFailure(op,
"Unsupported: non-constant center");
if (centerBool)
return rewriter.notifyMatchFailure(op, "Unsupported: center=True");

Value normalized = op.getNormalized();
bool normalizedBool;
// TODO: add support for non-constant normalized and normalized=True
if (!matchPattern(normalized, m_TorchConstantBool(&normalizedBool)))
return rewriter.notifyMatchFailure(
op, "Unsupported: non-constant normalized");
if (normalizedBool)
return rewriter.notifyMatchFailure(op, "Unsupported: normalized=True");

bool onesidedBool;
// Default: True for real input and window, False otherwise.
// TODO: add support for non-constant onesided
if (isa<Torch::NoneType>(op.getOnesided().getType())) {
Type dtype = selfType.getDtype();
onesidedBool = !isa<mlir::ComplexType>(dtype);
} else if (!matchPattern(op.getOnesided(),
m_TorchConstantBool(&onesidedBool)))
return rewriter.notifyMatchFailure(op,
"Unsupported: non-constant onesided");

Value returnComplex = op.getReturnComplex();
bool returnComplexBool;
// TODO: add support for non-constant return_complex and return_complex=True
if (!matchPattern(returnComplex, m_TorchConstantBool(&returnComplexBool)))
return rewriter.notifyMatchFailure(
op, "Unsupported: non-constant return_complex");
if (!returnComplex)
return rewriter.notifyMatchFailure(op,
"Unsupported: return_complex=False");
@zjgarvey (Collaborator):

Move all of these match failures before the generation of runtime assert ops in the if (hasWindow) block.
