Skip to content

[mlir] Add filtering callback to GenerateRuntimeVerification pass #150013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

hashemthomas1
Copy link
Contributor

Users would be able to create this pass and attach to it a custom callback function to filter out unwanted operations.

Users would be able to create this pass and attach to it a custom
callback function to filter out unwanted operations.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jul 22, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 22, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Thomas Hashem (hashemthomas1)

Changes

Users would be able to create this pass and attach to it a custom callback function to filter out unwanted operations.


Full diff: https://github.com/llvm/llvm-project/pull/150013.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Transforms/Passes.h (+8)
  • (modified) mlir/lib/Transforms/GenerateRuntimeVerification.cpp (+44-3)
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 9cd2ef34e15ea..4749a45e51c1f 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -26,6 +26,7 @@
 namespace mlir {
 
 class GreedyRewriteConfig;
+class RuntimeVerifiableOpInterface;
 
 //===----------------------------------------------------------------------===//
 // Passes
@@ -77,6 +78,13 @@ std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
 /// Creates a pass that generates IR to verify ops at runtime.
 std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
 
+/// Create an instance of the generate runtime verification pass, and
+/// use the provided filter function to skip certain verifiable ops.
+/// The default implementation does not filter any ops.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn);
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index a40bc2b3272fc..214510ca8ccd4 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -17,16 +17,46 @@ namespace mlir {
 #include "mlir/Transforms/Passes.h.inc"
 } // namespace mlir
 
+#define DEBUG_TYPE "generate-runtime-verification"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 
+static bool defaultShouldHandleVerifiableOpFn(RuntimeVerifiableOpInterface op) {
+  // By default, all verifiable ops are considered
+  return true;
+}
+
 namespace {
 struct GenerateRuntimeVerificationPass
     : public impl::GenerateRuntimeVerificationBase<
           GenerateRuntimeVerificationPass> {
+
+  GenerateRuntimeVerificationPass();
+  GenerateRuntimeVerificationPass(const GenerateRuntimeVerificationPass &) =
+      default;
+  GenerateRuntimeVerificationPass(
+      std::function<bool(RuntimeVerifiableOpInterface)>
+          shouldHandleVerifiableOpFn);
+
   void runOnOperation() override;
+
+private:
+  // A filter function to select verifiable ops to generate verification for.
+  // If empty, all verifiable ops are considered.
+  std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
 };
 } // namespace
 
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass()
+    : shouldHandleVerifiableOpFn(defaultShouldHandleVerifiableOpFn) {}
+
+GenerateRuntimeVerificationPass::GenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn)
+    : shouldHandleVerifiableOpFn(std::move(shouldHandleVerifiableOpFn)) {}
+
 void GenerateRuntimeVerificationPass::runOnOperation() {
   // The implementation of the RuntimeVerifiableOpInterface may create ops that
   // can be verified. We don't want to generate verification for IR that
@@ -38,11 +68,22 @@ void GenerateRuntimeVerificationPass::runOnOperation() {
 
   OpBuilder builder(getOperation()->getContext());
   for (RuntimeVerifiableOpInterface verifiableOp : ops) {
-    builder.setInsertionPoint(verifiableOp);
-    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
-  };
+    if (shouldHandleVerifiableOpFn(verifiableOp)) {
+      builder.setInsertionPoint(verifiableOp);
+      verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+    } else {
+      LDBG("Skipping operation: " << verifiableOp.getOperation());
+    }
+  }
 }
 
 std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
   return std::make_unique<GenerateRuntimeVerificationPass>();
 }
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass(
+    std::function<bool(RuntimeVerifiableOpInterface)>
+        shouldHandleVerifiableOpFn) {
+  return std::make_unique<GenerateRuntimeVerificationPass>(
+      std::move(shouldHandleVerifiableOpFn));
+}

private:
// A filter function to select verifiable ops to generate verification for.
// If empty, all verifiable ops are considered.
std::function<bool(RuntimeVerifiableOpInterface)> shouldHandleVerifiableOpFn;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joker-eph Can we have such non-serializable fields here or should this be a parameter for a new function that is publicly accessible and called from runOnOperation?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants