-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir] Add filtering callback to GenerateRuntimeVerification pass #150013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir] Add filtering callback to GenerateRuntimeVerification pass #150013
Conversation
Users would be able to create this pass and attach to it a custom callback function to filter out unwanted operations.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Thomas Hashem (hashemthomas1) ChangesUsers would be able to create this pass and attach to it a custom callback function to filter out unwanted operations. Full diff: https://github.com/llvm/llvm-project/pull/150013.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..4749a45e51c1f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,7 @@
namespace mlir {
class GreedyRewriteConfig;
+class RuntimeVerifiableOpInterface;
//===----------------------------------------------------------------------===//
// Passes
@@ -77,6 +78,13 @@ std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
/// Creates a pass that generates IR to verify ops at runtime.
std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
+/// Create an instance of the generate runtime verification pass, and
+/// use the provided filter function to skip certain verifiable ops.
+/// The default implementation does not filter any ops.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass(
+ std::function<bool(RuntimeVerifiableOpInterface)>
+ shouldHandleVerifiableOpFn);
+
/// Creates a loop invariant code motion pass that hoists loop invariant
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..214510ca8ccd4 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -17,16 +17,46 @@ namespace mlir {
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
+#define DEBUG_TYPE "generate-runtime-verification"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
using namespace mlir;
+static bool defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface op) {
+ // By default, all verifiable ops are considered
+ return true;
+}
+
namespace {
struct GenerateRuntimeVerificationPass
: public impl::GenerateRuntimeVerificationBase<
GenerateRuntimeVerificationPass> {
+
+ GenerateRuntimeVerificationPass();
+ GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) =
+ default;
+ GenerateRuntimeVerificationPass(
+ std::function<bool(RuntimeVerifiableOpInterface)>
+ shouldHandleVerifiableOpFn);
+
void runOnOperation() override;
+
+private:
+ // A filter function to select verifiable ops to generate verification for.
+ // If empty, all verifiable ops are considered.
+ std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
};
} // namespace
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass()
+ : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
+
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass(
+ std::function<bool(RuntimeVerifiableOpInterface)>
+ shouldHandleVerifiableOpFn)
+ : shouldHandleVerifiableOpFn(std::move(shouldHandleVerifiableOpFn)) {}
+
void GenerateRuntimeVerificationPass::runOnOperation() {
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
- builder.setInsertionPoint(verifiableOp);
- verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
- };
+ if (shouldHandleVerifiableOpFn(verifiableOp)) {
+ builder.setInsertionPoint(verifiableOp);
+ verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+ } else {
+ LDBG("Skipping operation: " << verifiableOp.getOperation());
+ }
+ }
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
return std::make_unique<GenerateRuntimeVerificationPass>();
}
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass(
+ std::function<bool(RuntimeVerifiableOpInterface)>
+ shouldHandleVerifiableOpFn) {
+ return std::make_unique<GenerateRuntimeVerificationPass>(
+ std::move(shouldHandleVerifiableOpFn));
+}
|
private: | ||
// A filter function to select verifiable ops to generate verification for. | ||
// If empty, all verifiable ops are considered. | ||
std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joker-eph Can we have such non-serializable fields here or should this be a parameter for a new function that is publicly accessible and called from runOnOperation
?
Users would be able to create this pass and attach to it a custom callback function to filter out unwanted operations.