diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 75f43bf3ea..78eb4398f3 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -16,6 +16,7 @@ "RewriterContext", "MatchingTracer", "MatchStatus", + "RULE_NAME_TAG", ] import onnx @@ -25,6 +26,7 @@ from onnxscript.rewriter import pattern from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( + RULE_NAME_TAG, RewriterContext, RewriteRule, RewriteRuleClassBase, diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 8964230fe0..9c88aa848e 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -25,6 +25,11 @@ RewriterContext = _tape.Builder +# TODO(rama): Standardize metadata property keys. May be worth standardizing at ONNX level for +# source/producer metadata. + +RULE_NAME_TAG = "pkg.onnxscript.rewriter.rule_name" + @dataclasses.dataclass class ReplacementSubgraph: @@ -719,6 +724,13 @@ def _apply_to_graph_or_function( _ir_utils.display_nodes(delta.new_nodes) print("++++End Replacement Nodes++++") + # Capture rewrite rule name as metadata. + # TODO(rama): This is just a basic version. We may wish to compose "source" metadata + # from multiple rules in future. + if rule.name: + for n in delta.new_nodes: + n.metadata_props[RULE_NAME_TAG] = rule.name + convenience.replace_nodes_and_values( graph_or_function, node, diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index 0a29080b4d..f296b5320c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -10,6 +10,7 @@ import onnx.parser import onnxscript.optimizer +import onnxscript.rewriter from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op from onnxscript.rewriter import pattern @@ -936,6 +937,44 @@ def add_pattern(op, x, y): match_result = rule_pattern.match(model, model.graph, add_nodes[2]) self.assertFalse(bool(match_result)) + def test_rule_name_metadata(self): + """Test that RewriteRule carries name metadata.""" + + class ReciprocalMulRule(pattern.RewriteRuleClassBase): + def __init__(self, name: str | None = None): + super().__init__(name) + + def pattern(self, op, x, y): + return (1 / x) * y + + def rewrite(self, op, x, y): + return op.Div(y, x) + + @script() + def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: + return op.Mul(op.Div(op.Constant(value_float=1.0), x), y) + + rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv") + model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv") + + # By default, the rule name is the class name (if not provided) + rule = ReciprocalMulRule.rule() + model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):