From 9cefe72432beed62ef1a052153fdc7d17ce6e996 Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" <165563006+Kuhai9801@users.noreply.github.com> Date: Wed, 17 Jun 2026 19:22:49 +0800 Subject: [PATCH 01/12] Topologically order poly auxiliary writes in compute Fixes project-llzk/llzk-lib#551. --- .../unreleased/poly-aux-write-order.yaml | 2 + lib/Transforms/LLZKPolyLoweringPass.cpp | 122 +++++++++++++++++- .../poly_lowering_aux_write_order.llzk | 37 ++++++ .../poly_lowering_composite_roots.llzk | 3 +- .../PolyLowering/poly_lowering_pass_deg3.llzk | 5 +- 5 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 changelogs/unreleased/poly-aux-write-order.yaml create mode 100644 test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk diff --git a/changelogs/unreleased/poly-aux-write-order.yaml b/changelogs/unreleased/poly-aux-write-order.yaml new file mode 100644 index 0000000000..c8a2ee061d --- /dev/null +++ b/changelogs/unreleased/poly-aux-write-order.yaml @@ -0,0 +1,2 @@ +fixed: + - Topologically order polynomial-lowering auxiliary writes in compute functions diff --git a/lib/Transforms/LLZKPolyLoweringPass.cpp b/lib/Transforms/LLZKPolyLoweringPass.cpp index 2784686fef..35bb9fc175 100644 --- a/lib/Transforms/LLZKPolyLoweringPass.cpp +++ b/lib/Transforms/LLZKPolyLoweringPass.cpp @@ -24,7 +24,10 @@ #include #include +#include +#include #include +#include #include #include @@ -51,6 +54,7 @@ namespace { struct AuxAssignment { std::string auxMemberName; Value computedValue; + Value auxValue; }; class PassImpl : public llzk::impl::PolyLoweringPassBase { @@ -66,6 +70,105 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { }); } + void addAuxDependency( + unsigned dep, unsigned owner, DenseSet &seenDeps, SmallVectorImpl &deps + ) const { + if (dep == owner) { + return; + } + if (seenDeps.insert(dep).second) { + deps.push_back(dep); + } + } + + void collectAuxDependencies( + Value val, unsigned owner, const DenseMap &auxValueToIndex, + const llvm::StringMap &auxNameToIndex, DenseSet &visitedValues, + DenseSet &seenDeps, SmallVectorImpl &deps + ) const { + if (!val || !visitedValues.insert(val).second) { + return; + } + + if (auto it = auxValueToIndex.find(val); it != auxValueToIndex.end()) { + addAuxDependency(it->second, owner, seenDeps, deps); + } + + if (auto readOp = val.getDefiningOp()) { + auto it = auxNameToIndex.find(readOp.getMemberName()); + if (it != auxNameToIndex.end()) { + addAuxDependency(it->second, owner, seenDeps, deps); + } + } + + Operation *defOp = val.getDefiningOp(); + if (!defOp) { + return; + } + + for (Value operand : defOp->getOperands()) { + collectAuxDependencies( + operand, owner, auxValueToIndex, auxNameToIndex, visitedValues, seenDeps, deps + ); + } + } + + LogicalResult visitAuxAssignment( + unsigned idx, ArrayRef> deps, SmallVectorImpl &visitState, + SmallVectorImpl &ordered, ArrayRef auxAssignments + ) const { + if (visitState[idx] == 2) { + return success(); + } + if (visitState[idx] == 1) { + return emitError(auxAssignments[idx].computedValue.getLoc()) + << "poly lowering generated cyclic auxiliary dependency involving @" + << auxAssignments[idx].auxMemberName; + } + + visitState[idx] = 1; + for (unsigned dep : deps[idx]) { + if (failed(visitAuxAssignment(dep, deps, visitState, ordered, auxAssignments))) { + return failure(); + } + } + visitState[idx] = 2; + ordered.push_back(idx); + return success(); + } + + LogicalResult orderAuxAssignments( + ArrayRef auxAssignments, SmallVectorImpl &ordered + ) const { + DenseMap auxValueToIndex; + llvm::StringMap auxNameToIndex; + auxValueToIndex.reserve(auxAssignments.size()); + for (auto [idx, assign] : llvm::enumerate(auxAssignments)) { + if (assign.auxValue) { + auxValueToIndex[assign.auxValue] = idx; + } + auxNameToIndex[assign.auxMemberName] = idx; + } + + SmallVector> deps(auxAssignments.size()); + for (auto [idx, assign] : llvm::enumerate(auxAssignments)) { + DenseSet visitedValues; + DenseSet seenDeps; + collectAuxDependencies( + assign.computedValue, idx, auxValueToIndex, auxNameToIndex, visitedValues, seenDeps, + deps[idx] + ); + } + + SmallVector visitState(auxAssignments.size(), 0); + for (unsigned idx = 0, e = auxAssignments.size(); idx < e; ++idx) { + if (failed(visitAuxAssignment(idx, deps, visitState, ordered, auxAssignments))) { + return failure(); + } + } + return success(); + } + // Recursively compute degree of FeltOps SSA values unsigned getDegree(Value val, DenseMap &memo) { if (auto it = memo.find(val); it != memo.end()) { @@ -183,7 +286,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { auto auxVal = builder.create( lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr() ); - auxAssignments.push_back({auxName, lhs}); + auxAssignments.push_back({auxName, lhs, auxVal}); Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()}); auto eqOp = builder.create(loc, auxVal, lhs); @@ -216,7 +319,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { // Emit constraint: auxVal == toFactor Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()}); auto eqOp = builder.create(loc, auxVal, toFactor); - auxAssignments.push_back({auxName, toFactor}); + auxAssignments.push_back({auxName, toFactor, auxVal}); // Update memoization rewrites[toFactor] = auxVal; degreeMemo[auxVal] = 1; // stays same @@ -275,7 +378,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { Location loc = builder.getFusedLoc({auxVal.getLoc(), loweredVal.getLoc()}); builder.create(loc, auxVal, loweredVal); - auxAssignments.push_back({auxName, loweredVal}); + auxAssignments.push_back({auxName, loweredVal, auxVal}); degreeMemo[auxVal] = 1; rewrites[loweredVal] = auxVal; @@ -460,13 +563,24 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator()); Value selfVal = computeFunc.getSelfValueFromCompute(); - for (const auto &assign : auxAssignments) { + SmallVector orderedAuxAssignments; + orderedAuxAssignments.reserve(auxAssignments.size()); + if (failed(orderAuxAssignments(auxAssignments, orderedAuxAssignments))) { + signalPassFailure(); + return; + } + + for (unsigned assignIdx : orderedAuxAssignments) { + const auto &assign = auxAssignments[assignIdx]; Value rebuiltExpr = rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo); builder.create( assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxMemberName), rebuiltExpr ); + if (assign.auxValue) { + rebuildMemo[assign.auxValue] = rebuiltExpr; + } } }); } diff --git a/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk b/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk new file mode 100644 index 0000000000..96d8e1eb5f --- /dev/null +++ b/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk @@ -0,0 +1,37 @@ +// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering --verify-each %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { + struct.def @CmpConstraint { + struct.member @out : !felt.type {llzk.pub} + + function.def @compute( + %a: !felt.type {function.arg_name = "a"}, + %b: !felt.type {function.arg_name = "b"} + ) -> !struct.type<@CmpConstraint> { + %self = struct.new : !struct.type<@CmpConstraint> + struct.writem %self[@out] = %a : !struct.type<@CmpConstraint>, !felt.type + function.return %self : !struct.type<@CmpConstraint> + } + + function.def @constrain(%self: !struct.type<@CmpConstraint>, %a: !felt.type, %b: !felt.type) { + %z = felt.mul %a, %b + %zz = felt.mul %z, %z + %za = felt.mul %z, %a + %za_sq = felt.mul %za, %za + constrain.eq %za_sq, %z : !felt.type + constrain.eq %zz, %za : !felt.type + function.return + } + } +} + +// CHECK-LABEL: struct.def @CmpConstraint +// CHECK-LABEL: function.def @compute +// CHECK: %[[SELF:.*]] = struct.new : <@CmpConstraint> +// CHECK-NOT: struct.readm %[[SELF]][@__llzk_poly_lowering_pass_aux_member_ +// CHECK: %[[Z:.*]] = felt.mul %{{.*}}, %{{.*}} : !felt.type, !felt.type +// CHECK: struct.writem %[[SELF]][@__llzk_poly_lowering_pass_aux_member_{{[0-9]+}}] = %[[Z]] : <@CmpConstraint>, !felt.type +// CHECK: %[[ZA:.*]] = felt.mul %[[Z]], %{{.*}} : !felt.type, !felt.type +// CHECK: struct.writem %[[SELF]][@__llzk_poly_lowering_pass_aux_member_{{[0-9]+}}] = %[[ZA]] : <@CmpConstraint>, !felt.type +// CHECK-NOT: struct.readm %[[SELF]][@__llzk_poly_lowering_pass_aux_member_ +// CHECK: function.return %[[SELF]] : !struct.type<@CmpConstraint> diff --git a/test/Transforms/PolyLowering/poly_lowering_composite_roots.llzk b/test/Transforms/PolyLowering/poly_lowering_composite_roots.llzk index fb2aae76d5..7189339aa8 100644 --- a/test/Transforms/PolyLowering/poly_lowering_composite_roots.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_composite_roots.llzk @@ -167,8 +167,7 @@ module attributes {llzk.lang} { // CHECK: %[[C_SELF:.*]] = struct.new : <@CompositeCallArg> // CHECK: %[[C_AB:.*]] = felt.mul %[[C_A]], %[[C_B]] : !felt.type, !felt.type // CHECK: struct.writem %[[C_SELF]][@__llzk_poly_lowering_pass_aux_member_0] = %[[C_AB]] : <@CompositeCallArg>, !felt.type -// CHECK: %[[C_AUX0:.*]] = struct.readm %[[C_SELF]][@__llzk_poly_lowering_pass_aux_member_0] : <@CompositeCallArg>, !felt.type -// CHECK: %[[C_ABC:.*]] = felt.mul %[[C_AUX0]], %[[C_C]] : !felt.type, !felt.type +// CHECK: %[[C_ABC:.*]] = felt.mul %[[C_AB]], %[[C_C]] : !felt.type, !felt.type // CHECK: %[[C_ARG:.*]] = felt.add %[[C_ABC]], %[[C_D]] : !felt.type, !felt.type // CHECK: struct.writem %[[C_SELF]][@__llzk_poly_lowering_pass_aux_member_1] = %[[C_ARG]] : <@CompositeCallArg>, !felt.type // CHECK: function.return %[[C_SELF]] : !struct.type<@CompositeCallArg> diff --git a/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk b/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk index d43c2e1241..afe8eff7e4 100644 --- a/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk @@ -23,11 +23,10 @@ module attributes {llzk.lang} { // CHECK-LABEL: struct.def @CmpConstraint { // CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@CmpConstraint> attributes {function.allow_witness} { // CHECK: %[[VAL_2:.*]] = struct.new : <@CmpConstraint> -// CHECK: %[[VAL_3:.*]] = struct.readm %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_1] : <@CmpConstraint>, !felt.type +// CHECK: %[[VAL_3:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type +// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_3]] : <@CmpConstraint>, !felt.type // CHECK: %[[VAL_4:.*]] = felt.mul %[[VAL_3]], %[[VAL_0]] : !felt.type, !felt.type // CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_4]] : <@CmpConstraint>, !felt.type -// CHECK: %[[VAL_5:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_5]] : <@CmpConstraint>, !felt.type // CHECK: function.return %[[VAL_2]] : !struct.type<@CmpConstraint> // CHECK: } // CHECK: function.def @constrain(%[[VAL_6:.*]]: !struct.type<@CmpConstraint>, %[[VAL_7:.*]]: !felt.type, %[[VAL_8:.*]]: !felt.type) attributes {function.allow_constraint} { From be67e8d7f3ef52334c25104468d4a4482305da7b Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" <165563006+Kuhai9801@users.noreply.github.com> Date: Wed, 17 Jun 2026 21:41:41 +0800 Subject: [PATCH 02/12] Use existing aux write order regression --- .../poly_lowering_aux_write_order.llzk | 37 ------------------- 1 file changed, 37 deletions(-) delete mode 100644 test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk diff --git a/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk b/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk deleted file mode 100644 index 96d8e1eb5f..0000000000 --- a/test/Transforms/PolyLowering/poly_lowering_aux_write_order.llzk +++ /dev/null @@ -1,37 +0,0 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering --verify-each %s | FileCheck --enable-var-scope %s - -module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { - struct.def @CmpConstraint { - struct.member @out : !felt.type {llzk.pub} - - function.def @compute( - %a: !felt.type {function.arg_name = "a"}, - %b: !felt.type {function.arg_name = "b"} - ) -> !struct.type<@CmpConstraint> { - %self = struct.new : !struct.type<@CmpConstraint> - struct.writem %self[@out] = %a : !struct.type<@CmpConstraint>, !felt.type - function.return %self : !struct.type<@CmpConstraint> - } - - function.def @constrain(%self: !struct.type<@CmpConstraint>, %a: !felt.type, %b: !felt.type) { - %z = felt.mul %a, %b - %zz = felt.mul %z, %z - %za = felt.mul %z, %a - %za_sq = felt.mul %za, %za - constrain.eq %za_sq, %z : !felt.type - constrain.eq %zz, %za : !felt.type - function.return - } - } -} - -// CHECK-LABEL: struct.def @CmpConstraint -// CHECK-LABEL: function.def @compute -// CHECK: %[[SELF:.*]] = struct.new : <@CmpConstraint> -// CHECK-NOT: struct.readm %[[SELF]][@__llzk_poly_lowering_pass_aux_member_ -// CHECK: %[[Z:.*]] = felt.mul %{{.*}}, %{{.*}} : !felt.type, !felt.type -// CHECK: struct.writem %[[SELF]][@__llzk_poly_lowering_pass_aux_member_{{[0-9]+}}] = %[[Z]] : <@CmpConstraint>, !felt.type -// CHECK: %[[ZA:.*]] = felt.mul %[[Z]], %{{.*}} : !felt.type, !felt.type -// CHECK: struct.writem %[[SELF]][@__llzk_poly_lowering_pass_aux_member_{{[0-9]+}}] = %[[ZA]] : <@CmpConstraint>, !felt.type -// CHECK-NOT: struct.readm %[[SELF]][@__llzk_poly_lowering_pass_aux_member_ -// CHECK: function.return %[[SELF]] : !struct.type<@CmpConstraint> From b4b4e27bc0573db8fd7a6bf51375d9082a9d7803 Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" <165563006+Kuhai9801@users.noreply.github.com> Date: Wed, 17 Jun 2026 21:59:17 +0800 Subject: [PATCH 03/12] Match changelog entry to branch name --- .../{poly-aux-write-order.yaml => fix__poly-aux-write-order.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename changelogs/unreleased/{poly-aux-write-order.yaml => fix__poly-aux-write-order.yaml} (100%) diff --git a/changelogs/unreleased/poly-aux-write-order.yaml b/changelogs/unreleased/fix__poly-aux-write-order.yaml similarity index 100% rename from changelogs/unreleased/poly-aux-write-order.yaml rename to changelogs/unreleased/fix__poly-aux-write-order.yaml From 43464d6722a8e3f6d31afaa7df030588c1c1faf7 Mon Sep 17 00:00:00 2001 From: sgtpepper <165563006+Kuhai9801@users.noreply.github.com> Date: Thu, 18 Jun 2026 00:03:34 +0800 Subject: [PATCH 04/12] Emit typed auxiliary members during lowering (#556) --- .../lib/r1cs/Transforms/R1CSLoweringPass.cpp | 6 +-- .../fix__poly-typed-aux-members.yaml | 4 ++ include/llzk/Transforms/LLZKLoweringUtils.h | 4 +- lib/Transforms/LLZKLoweringUtils.cpp | 8 ++-- lib/Transforms/LLZKPolyLoweringPass.cpp | 6 +-- .../poly_lowering_typed_aux_member.llzk | 39 +++++++++++++++++++ .../r1cs_lowering_typed_aux_member.llzk | 38 ++++++++++++++++++ 7 files changed, 94 insertions(+), 11 deletions(-) create mode 100644 changelogs/unreleased/fix__poly-typed-aux-members.yaml create mode 100644 test/Transforms/PolyLowering/poly_lowering_typed_aux_member.llzk create mode 100644 test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk diff --git a/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp b/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp index 87ba831ed3..880b02f398 100644 --- a/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp +++ b/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp @@ -340,7 +340,7 @@ class PassImpl : public r1cs::impl::R1CSLoweringPassBase { if (degLhs == 2 && degRhs == 2) { builder.setInsertionPoint(op); std::string auxName = R1CS_AUXILIARY_MEMBER_PREFIX + std::to_string(auxCounter++); - MemberDefOp auxMember = addAuxMember(structDef, auxName); + MemberDefOp auxMember = addAuxMember(structDef, auxName, val.getType()); Value aux = builder.create( val.getLoc(), val.getType(), constrainFunc.getSelfValueFromConstrain(), auxMember.getNameAttr() @@ -532,7 +532,7 @@ class PassImpl : public r1cs::impl::R1CSLoweringPassBase { // Entire linear combination was zero result = builder.create( loc, r1cs::LinearType::get(builder.getContext()), - r1cs::FeltAttr::get(builder.getContext(), 0) + r1cs::FeltAttr::get(builder.getContext(), toAPSInt(lc.constant)) ); } @@ -673,7 +673,7 @@ class PassImpl : public r1cs::impl::R1CSLoweringPassBase { if (degLhs == 2 && degRhs == 2) { builder.setInsertionPoint(eqOp); std::string auxName = R1CS_AUXILIARY_MEMBER_PREFIX + std::to_string(auxCounter++); - MemberDefOp auxMember = addAuxMember(structDef, auxName); + MemberDefOp auxMember = addAuxMember(structDef, auxName, lhs.getType()); Value aux = builder.create( eqOp.getLoc(), lhs.getType(), constrainFunc.getSelfValueFromConstrain(), auxMember.getNameAttr() diff --git a/changelogs/unreleased/fix__poly-typed-aux-members.yaml b/changelogs/unreleased/fix__poly-typed-aux-members.yaml new file mode 100644 index 0000000000..d4561d7685 --- /dev/null +++ b/changelogs/unreleased/fix__poly-typed-aux-members.yaml @@ -0,0 +1,4 @@ +fixed: + - Emit polynomial-lowering auxiliary members with the exact type of the materialized expression + - Emit all R1CS-lowering auxiliary members with the exact type of the materialized expression + - Emit synthesized zero R1CS linear-combination constants with a printable integer width diff --git a/include/llzk/Transforms/LLZKLoweringUtils.h b/include/llzk/Transforms/LLZKLoweringUtils.h index 3a346dbd23..e7af800873 100644 --- a/include/llzk/Transforms/LLZKLoweringUtils.h +++ b/include/llzk/Transforms/LLZKLoweringUtils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -39,7 +40,8 @@ checkForAuxMemberConflicts(component::StructDefOp structDef, llvm::StringRef aux mlir::LogicalResult checkConstrainBodyIsStraightLine(function::FuncDefOp constrainFunc, llvm::StringRef passName); -component::MemberDefOp addAuxMember(component::StructDefOp structDef, llvm::StringRef name); +component::MemberDefOp +addAuxMember(component::StructDefOp structDef, llvm::StringRef name, mlir::Type type); unsigned getFeltDegree(mlir::Value val, llvm::DenseMap &memo); diff --git a/lib/Transforms/LLZKLoweringUtils.cpp b/lib/Transforms/LLZKLoweringUtils.cpp index 34effc94da..ee206bcd11 100644 --- a/lib/Transforms/LLZKLoweringUtils.cpp +++ b/lib/Transforms/LLZKLoweringUtils.cpp @@ -148,12 +148,12 @@ void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) { } } -MemberDefOp addAuxMember(StructDefOp structDef, StringRef name) { +MemberDefOp addAuxMember(StructDefOp structDef, StringRef name, Type type) { + assert(type && "auxiliary member type must be non-null"); + OpBuilder builder(structDef); builder.setInsertionPointToEnd(structDef.getBody()); - return builder.create( - structDef.getLoc(), builder.getStringAttr(name), builder.getType() - ); + return builder.create(structDef.getLoc(), builder.getStringAttr(name), type); } unsigned getFeltDegree(Value val, DenseMap &memo) { diff --git a/lib/Transforms/LLZKPolyLoweringPass.cpp b/lib/Transforms/LLZKPolyLoweringPass.cpp index 2784686fef..ca8e84d49b 100644 --- a/lib/Transforms/LLZKPolyLoweringPass.cpp +++ b/lib/Transforms/LLZKPolyLoweringPass.cpp @@ -178,7 +178,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { // Optimization: If lhs == rhs, factor it only once if (lhs == rhs && eraseMul) { std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++); - MemberDefOp auxMember = addAuxMember(structDef, auxName); + MemberDefOp auxMember = addAuxMember(structDef, auxName, lhs.getType()); auto auxVal = builder.create( lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr() @@ -206,7 +206,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { // Create auxiliary member for toFactor std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++); - MemberDefOp auxMember = addAuxMember(structDef, auxName); + MemberDefOp auxMember = addAuxMember(structDef, auxName, toFactor.getType()); // Read back as MemberReadOp (new SSA value) auto auxVal = builder.create( @@ -265,7 +265,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { // Callees only receive SSA values, not the caller expression tree, so nonlinear // call arguments must be represented by an auxiliary member read. std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++); - MemberDefOp auxMember = addAuxMember(structDef, auxName); + MemberDefOp auxMember = addAuxMember(structDef, auxName, loweredVal.getType()); OpBuilder builder(callOp); Value selfVal = constrainFunc.getSelfValueFromConstrain(); diff --git a/test/Transforms/PolyLowering/poly_lowering_typed_aux_member.llzk b/test/Transforms/PolyLowering/poly_lowering_typed_aux_member.llzk new file mode 100644 index 0000000000..8d3dc05b3d --- /dev/null +++ b/test/Transforms/PolyLowering/poly_lowering_typed_aux_member.llzk @@ -0,0 +1,39 @@ +// RUN: llzk-opt -I %S -split-input-file -llzk-poly-lowering-pass="max-degree=2" --verify-each %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @TypedAux { + struct.member @out : !felt.type<"babybear"> + + function.def @compute( + %a: !felt.type<"babybear">, + %b: !felt.type<"babybear">, + %c: !felt.type<"babybear"> + ) -> !struct.type<@TypedAux> { + %self = struct.new : !struct.type<@TypedAux> + function.return %self : !struct.type<@TypedAux> + } + + function.def @constrain( + %self: !struct.type<@TypedAux>, + %a: !felt.type<"babybear">, + %b: !felt.type<"babybear">, + %c: !felt.type<"babybear"> + ) { + %ab = felt.mul %a, %b : !felt.type<"babybear">, !felt.type<"babybear"> + %abc = felt.mul %ab, %c : !felt.type<"babybear">, !felt.type<"babybear"> + %out = struct.readm %self[@out] : !struct.type<@TypedAux>, !felt.type<"babybear"> + constrain.eq %out, %abc : !felt.type<"babybear"> + function.return + } + } +} + +// CHECK-LABEL: struct.def @TypedAux +// CHECK-LABEL: function.def @compute +// CHECK: %[[AB:.*]] = felt.mul %{{.*}}, %{{.*}} : !felt.type<"babybear">, !felt.type<"babybear"> +// CHECK: struct.writem %{{.*}}[@__llzk_poly_lowering_pass_aux_member_0] = %[[AB]] : <@TypedAux>, !felt.type<"babybear"> +// CHECK-LABEL: function.def @constrain +// CHECK: %[[AB_CONSTRAIN:.*]] = felt.mul %{{.*}}, %{{.*}} : !felt.type<"babybear">, !felt.type<"babybear"> +// CHECK: %[[AUX:.*]] = struct.readm %{{.*}}[@__llzk_poly_lowering_pass_aux_member_0] : <@TypedAux>, !felt.type<"babybear"> +// CHECK: constrain.eq %[[AUX]], %[[AB_CONSTRAIN]] : !felt.type<"babybear">, !felt.type<"babybear"> +// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type<"babybear"> diff --git a/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk b/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk new file mode 100644 index 0000000000..3a0e0abaee --- /dev/null +++ b/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk @@ -0,0 +1,38 @@ +// RUN: llzk-opt -split-input-file -llzk-full-r1cs-lowering --verify-each %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @TypedR1CSAux { + struct.member @out : !felt.type<"babybear"> {llzk.pub} + + function.def @compute( + %a: !felt.type<"babybear">, + %b: !felt.type<"babybear">, + %c: !felt.type<"babybear"> + ) -> !struct.type<@TypedR1CSAux> { + %self = struct.new : !struct.type<@TypedR1CSAux> + %ab = felt.mul %a, %b : !felt.type<"babybear">, !felt.type<"babybear"> + struct.writem %self[@out] = %ab : !struct.type<@TypedR1CSAux>, !felt.type<"babybear"> + function.return %self : !struct.type<@TypedR1CSAux> + } + + function.def @constrain( + %self: !struct.type<@TypedR1CSAux>, + %a: !felt.type<"babybear"> {llzk.pub}, + %b: !felt.type<"babybear">, + %c: !felt.type<"babybear"> + ) { + %ab = felt.mul %a, %b : !felt.type<"babybear">, !felt.type<"babybear"> + %bc = felt.mul %b, %c : !felt.type<"babybear">, !felt.type<"babybear"> + constrain.eq %ab, %bc : !felt.type<"babybear"> + %out = struct.readm %self[@out] : !struct.type<@TypedR1CSAux>, !felt.type<"babybear"> + constrain.eq %out, %ab : !felt.type<"babybear"> + function.return + } + } +} + +// CHECK-LABEL: r1cs.circuit @TypedR1CSAux inputs +// CHECK-SAME: %{{[0-9a-zA-Z_\.]+}}: !r1cs.signal {#r1cs.pub} +// CHECK-SAME: %{{[0-9a-zA-Z_\.]+}}: !r1cs.signal +// CHECK-SAME: %{{[0-9a-zA-Z_\.]+}}: !r1cs.signal +// CHECK: r1cs.constrain From c1f5ab5fdd2ac5185aab6cb03b450b6bf4e3655f Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Thu, 18 Jun 2026 12:25:32 -0500 Subject: [PATCH 05/12] Instantiate wildcard type params in flattening pass and add `llzk-specialize-wildcard-arrays` pass (#538) --- .../unreleased/th__instantiate_wildcard.yaml | 5 + .../Transforms/TransformationPasses.td | 18 + .../Polymorphic/Transforms/FlatteningPass.cpp | 794 +++++++++---- .../WildcardArraySpecializationPass.cpp | 1051 +++++++++++++++++ .../Flattening/instantiate_wildcard.llzk | 252 ++++ .../specialize_wildcard.llzk | 331 ++++++ 6 files changed, 2230 insertions(+), 221 deletions(-) create mode 100644 changelogs/unreleased/th__instantiate_wildcard.yaml create mode 100644 lib/Dialect/Polymorphic/Transforms/WildcardArraySpecializationPass.cpp create mode 100644 test/Transforms/Flattening/instantiate_wildcard.llzk create mode 100644 test/Transforms/WildcardArraySpecialization/specialize_wildcard.llzk diff --git a/changelogs/unreleased/th__instantiate_wildcard.yaml b/changelogs/unreleased/th__instantiate_wildcard.yaml new file mode 100644 index 0000000000..027bdaef66 --- /dev/null +++ b/changelogs/unreleased/th__instantiate_wildcard.yaml @@ -0,0 +1,5 @@ +fixed: + - Handle wildcard CallOp template parameters in the flattening pass. + +added: + - '`llzk-specialize-wildcard-arrays` pass to refine array types with wildcard dimensions' diff --git a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td index ca4ec92ea0..51c9f20dcc 100644 --- a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td +++ b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td @@ -89,4 +89,22 @@ def FlatteningPass : LLZKPass<"llzk-flatten"> { ]; } +def WildcardArraySpecializationPass + : LLZKPass<"llzk-specialize-wildcard-arrays"> { + let summary = + "Refine wildcard array casts and specialize concrete call targets"; + let description = [{ + Refines `poly.unifiable_cast` results when wildcard `array.type` dimensions can be replaced + with concrete integer sizes from the input type. Then specializes free functions, extern + declarations, and whole structs for calls whose wildcard array dimensions have become concrete. + + This pass is intended to run after `llzk-flatten`. It iterates to a fixpoint so newly-refined + cast result types can enable further callable specialization in later iterations. + }]; + let options = [Option<"iterationLimit", "max-iter", "unsigned", + /* default */ "1000", + "Maximum number of iterations before the pass gives up " + "reaching a fixpoint.">]; +} + #endif // LLZK_POLYMORPHIC_TRANSFORMATION_PASSES_TD diff --git a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp index 909fa41a0f..a56de79236 100644 --- a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp +++ b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp @@ -59,6 +59,8 @@ #include #include +#include + // Include the generated base pass class definitions. namespace llzk::polymorphic { #define GEN_PASS_DEF_FLATTENINGPASS @@ -365,14 +367,29 @@ applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternS return failure(result.failed() || failureListener.hadFailure); } -template bool isConcreteAttr(Attribute a) { +/// Classifies the concreteness of an attribute value for the purposes of determining +/// if a struct instantiation can replace a parameter reference with that value. +enum class AttrConcreteness : std::uint8_t { + NonConcrete, + Concrete, + Wildcard, +}; + +/// Classify the concreteness of the given attribute value for the purposes of struct instantiation. +template AttrConcreteness classifyAttrConcreteness(Attribute a) { if (TypeAttr tyAttr = dyn_cast(a)) { - return isConcreteType(tyAttr.getValue(), AllowStructParams); + return isConcreteType(tyAttr.getValue(), AllowStructParams) ? AttrConcreteness::Concrete + : AttrConcreteness::NonConcrete; } if (IntegerAttr intAttr = dyn_cast(a)) { - return !isDynamic(intAttr); + return isDynamic(intAttr) ? AttrConcreteness::Wildcard : AttrConcreteness::Concrete; } - return false; + return AttrConcreteness::NonConcrete; +} + +/// Return true if the given attribute value is concrete for the purposes of struct instantiation. +template bool isConcreteAttr(Attribute a) { + return classifyAttrConcreteness(a) == AttrConcreteness::Concrete; } static SymbolRefAttr @@ -952,6 +969,18 @@ LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) { namespace Step1B_InstantiateFunctions { +/// Flatten nested array instantiations by appending any dimensions contributed by the converted +/// element type onto the outer array. This allows wildcard element types to resolve to +/// higher-rank arrays even though LLZK array element types cannot themselves be arrays. +static ArrayType flattenInstantiatedArrayType(ArrayType inputTy, Type convertedElemTy) { + SmallVector mergedDims(inputTy.getDimensionSizes()); + while (ArrayType nestedArrTy = llvm::dyn_cast(convertedElemTy)) { + llvm::append_range(mergedDims, nestedArrTy.getDimensionSizes()); + convertedElemTy = nestedArrTy.getElementType(); + } + return ArrayType::get(convertedElemTy, mergedDims); +} + /// TypeConverter for function instantiation that replaces TypeVarType and symbolic /// ArrayType/StructType parameters with their concrete values determined by unification. class FuncInstTypeConverter : public TypeConverter { @@ -991,7 +1020,9 @@ class FuncInstTypeConverter : public TypeConverter { if (!changed && newElemTy == inputTy.getElementType()) { return inputTy; } - return ArrayType::get(newElemTy, updated); + return flattenInstantiatedArrayType( + inputTy.cloneWith(inputTy.getElementType(), updated), newElemTy + ); }); addConversion([this](StructType inputTy) -> StructType { @@ -1026,10 +1057,294 @@ class FuncInstTypeConverter : public TypeConverter { }); } + Attribute convertAttr(Attribute attr) const { + if (TypeAttr tyAttr = llvm::dyn_cast(attr)) { + Type convertedTy = convertType(tyAttr.getValue()); + if (convertedTy != tyAttr.getValue()) { + return TypeAttr::get(convertedTy); + } + } + return convertIfPossible(attr); + } + bool containsParam(Attribute nameAttr) const { return paramNameToValue.contains(nameAttr); } const DenseMap &getParamMap() const { return paramNameToValue; } }; +/// Return the callee-side unification-derived value for a template parameter, if any. +inline static std::optional +inferUnifiedParam(const UnificationMap &unifyResult, SymbolRefAttr paramName) { + auto it = unifyResult.find({paramName, Side::RHS}); + return (it == unifyResult.end()) ? std::nullopt : std::make_optional(it->second); +} + +/// Emit the match failure used when an inferred instantiation violates a template parameter's +/// declared type restriction. +inline static LogicalResult failIncompatibleInferredParam( + CallOp op, PatternRewriter &rewriter, FlatSymbolRefAttr paramName, TemplateParamOp paramOp +) { + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName + << "': incompatible with specified param type. MUST FAIL!\n" + ); + return rewriter.notifyMatchFailure(op, [¶mName, ¶mOp](Diagnostic &diag) { + diag.append("inferred value for parameter '") + .append(paramName) + .append("' is incompatible with specified param type") + .attachNote(paramOp.getLoc()) + .append("template parameter declared here"); + }); +} + +/// Searches a parameterized callee body for concrete type evidence that resolves a wildcard +/// template parameter, following both local unifiable casts and nested template calls. +class WildcardTypeBodyInferer final { + SymbolTableCollection &symTables_; + const DenseMap ¶mNameToConcrete_; + SmallVector> activeInferences_; + +public: + WildcardTypeBodyInferer( + SymbolTableCollection &symTables, const DenseMap ¶mNameToConcrete + ) + : symTables_(symTables), paramNameToConcrete_(paramNameToConcrete) {} + + std::optional infer(FuncDefOp func, FlatSymbolRefAttr paramName) { + if (llvm::any_of(activeInferences_, [&](const auto &e) { + return e.first == func.getOperation() && e.second == paramName; + })) { + return std::nullopt; + } + activeInferences_.emplace_back(func.getOperation(), paramName); + + FuncInstTypeConverter tyConv((paramNameToConcrete_)); + std::optional inferred; + bool ambiguous = false; + + // Record a concrete candidate unless it conflicts with an earlier one, in which + // case the wildcard is treated as ambiguous and left unresolved. + auto noteCandidate = [&inferred, &ambiguous](Attribute candidate) { + if (!candidate || !isConcreteAttr(candidate)) { + return WalkResult::advance(); + } + if (!inferred.has_value()) { + inferred = candidate; + return WalkResult::advance(); + } + if (*inferred != candidate) { + ambiguous = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }; + + WalkResult walkResult = func.walk([&](Operation *bodyOp) { + if (auto castOp = llvm::dyn_cast(bodyOp)) { + Type inputTy = tyConv.convertType(castOp.getInput().getType()); + Type resultTy = tyConv.convertType(castOp.getResult().getType()); + if (auto inputTvar = llvm::dyn_cast(inputTy); + inputTvar && inputTvar.getNameRef() == paramName && isConcreteType(resultTy)) { + return noteCandidate(TypeAttr::get(resultTy)); + } + if (auto resultTvar = llvm::dyn_cast(resultTy); + resultTvar && resultTvar.getNameRef() == paramName && isConcreteType(inputTy)) { + return noteCandidate(TypeAttr::get(inputTy)); + } + return WalkResult::advance(); + } + + auto nestedCall = llvm::dyn_cast(bodyOp); + if (!nestedCall) { + return WalkResult::advance(); + } + + FailureOr> nestedTgtOpt = + nestedCall.getCalleeTarget(symTables_); + if (failed(nestedTgtOpt)) { + return WalkResult::advance(); + } + FuncDefOp nestedTgt = nestedTgtOpt->get(); + auto nestedTemplate = llvm::dyn_cast(nestedTgt->getParentOp()); + if (!nestedTemplate) { + return WalkResult::advance(); + } + + TypeRange nestedResultTypes = nestedTgt.getFunctionType().getResults(); + for (auto [result, nestedResultTy] : + llvm::zip_equal(nestedCall.getResults(), nestedResultTypes)) { + Type convertedResultTy = tyConv.convertType(result.getType()); + auto resultTvar = llvm::dyn_cast(convertedResultTy); + auto nestedTvar = llvm::dyn_cast(nestedResultTy); + if (!resultTvar || !nestedTvar || resultTvar.getNameRef() != paramName) { + continue; + } + if (std::optional candidate = inferFromExplicitNestedCallParams( + nestedCall, nestedTemplate, nestedTvar.getNameRef(), tyConv + )) { + WalkResult candidateResult = noteCandidate(*candidate); + if (candidateResult.wasInterrupted()) { + return candidateResult; + } + continue; + } + if (std::optional candidate = infer(nestedTgt, nestedTvar.getNameRef())) { + WalkResult candidateResult = noteCandidate(*candidate); + if (candidateResult.wasInterrupted()) { + return candidateResult; + } + } + } + return WalkResult::advance(); + }); + + activeInferences_.pop_back(); + if (ambiguous || (walkResult.wasInterrupted() && !inferred.has_value())) { + return std::nullopt; + } + return inferred; + } + +private: + std::optional inferFromExplicitNestedCallParams( + CallOp nestedCall, TemplateOp nestedTemplate, FlatSymbolRefAttr nestedParamName, + const FuncInstTypeConverter &tyConv + ) const { + ArrayAttr nestedCallParams = nestedCall.getTemplateParamsAttr(); + if (isNullOrEmpty(nestedCallParams)) { + return std::nullopt; + } + + for (auto [paramOp, attr] : + llvm::zip_equal(nestedTemplate.getConstOps(), nestedCallParams)) { + auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr()); + if (paramName != nestedParamName) { + continue; + } + Attribute convertedAttr = tyConv.convertAttr(attr); + return isConcreteAttr(convertedAttr) ? std::make_optional(convertedAttr) : std::nullopt; + } + return std::nullopt; + } +}; + +/// Groups the information needed after concrete parameters have been chosen to decide whether to +/// build a full or partial instantiation and how to rewrite the call site. +struct InstantiationLayout { + SmallVector remainingNames; + std::string templateNameWithAttrs; + ArrayAttr rewrittenCallParams; +}; + +/// Derive the (partially-)instantiated template name and the remaining explicit call parameters +/// that should stay on the rewritten call. Partially-instantiated names will contain the `\x1A` +/// placeholder character at the position of a non-concrete parameter: "TemplateName_8_\x1A". +static InstantiationLayout buildInstantiationLayout( + TemplateOp parentTemplate, ArrayAttr callParams, + const DenseMap ¶mNameToConcrete +) { + SmallVector remainingNames; + SmallVector attrsForInstantiatedNameSuffix; + for (Attribute paramName : parentTemplate.getConstNames()) { + auto it = paramNameToConcrete.find(paramName); + if (it != paramNameToConcrete.end()) { + attrsForInstantiatedNameSuffix.push_back(it->second); + } else { + attrsForInstantiatedNameSuffix.push_back(nullptr); + remainingNames.push_back(paramName); + } + } + + ArrayAttr rewrittenCallParams = nullptr; + if (!isNullOrEmpty(callParams) && !remainingNames.empty()) { + SmallVector remainingCallParams; + for (auto [paramOp, attr] : + llvm::zip_equal(parentTemplate.getConstOps(), callParams.getValue())) { + auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr()); + if (!paramNameToConcrete.contains(paramName)) { + remainingCallParams.push_back(attr); + } + } + rewrittenCallParams = ArrayAttr::get(parentTemplate.getContext(), remainingCallParams); + } + + return { + std::move(remainingNames), + BuildShortTypeString::from(parentTemplate.getSymName().str(), attrsForInstantiatedNameSuffix), + rewrittenCallParams, + }; +} + +/// Rewrite cloned scalar array reads to ranged extract ops when a wildcard element type +/// resolves to a higher-rank array. +class ClonedBodyArrayReadOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ReadArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + Type newResultTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!llvm::isa(newResultTy)) { + return failure(); + } + replaceOpWithNewOp( + rewriter, op, newResultTy, adaptor.getArrRef(), adaptor.getIndices() + ); + return success(); + } +}; + +/// Rewrite cloned scalar array writes to ranged inserts when a wildcard element type +/// resolves to a higher-rank array. +class ClonedBodyArrayWriteOpPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + WriteArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter + ) const override { + if (!llvm::isa(adaptor.getRvalue().getType())) { + return failure(); + } + replaceOpWithNewOp( + rewriter, op, adaptor.getArrRef(), adaptor.getIndices(), adaptor.getRvalue() + ); + return success(); + } +}; + +/// Use `FuncInstTypeConverter` to apply the given substitutions from instantiation and verify +/// that `CallOp` in the converted function are valid for their respective targets (we can emit a +/// more helpful error at this point rather than discovering it later when verifying the module). +static LogicalResult applyBodyConversions( + CallOp op, FuncDefOp newFunc, const DenseMap ¶mNameToConcrete +) { + MLIRContext *ctx = op.getContext(); + FuncInstTypeConverter tyConv(paramNameToConcrete); + ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx); + target.addDynamicallyLegalOp([&tyConv](ConstReadOp p) { + // Legal if it's not in the map of concrete attribute instantiations + return !tyConv.containsParam(p.getConstNameAttr()); + }); + SmallVector delayedDiagnostics; + RewritePatternSet bodyPatterns = newGeneralRewritePatternSet(tyConv, ctx, target); + bodyPatterns.add( + tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics + ); + bodyPatterns.add(tyConv, ctx); + if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) { + return failure(); + } + LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] instantiated clone: " << newFunc << '\n'); + ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics)); + + SymbolTableCollection tables; + WalkResult res = newFunc.walk([&tables](CallOp nestedCall) { + return WalkResult(nestedCall.verifySymbolUses(tables)); + }); + return failure(res.wasInterrupted()); +} + class InstantiateFuncAtCallOp final : public OpRewritePattern { ConversionTracker &tracker_; @@ -1071,14 +1386,9 @@ class InstantiateFuncAtCallOp final : public OpRewritePattern { // middle but the overall chain does not unify. Hence, this unification may fail and should // produce a meaningful error message if it does. // See: `test/Transforms/Flattening/instantiate_funcs_fail.llzk` - FailureOr unifyResult = op.unifyTypeSignature(callTgt.getFunctionType()); + FailureOr unifyResult = unifyTypeSignature(op, callTgt, rewriter); if (failed(unifyResult)) { - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("target function type does not unify with call type ") - .append(op.getTypeSignature()) - .attachNote(callTgt.getLoc()) - .append("target function declared here"); - }); + return failure(); } LLVM_DEBUG( llvm::dbgs() << "[InstantiateFuncAtCallOp] unifications of types: " @@ -1087,24 +1397,110 @@ class InstantiateFuncAtCallOp final : public OpRewritePattern { // Maps template parameter symbols to the instantiation value at the call site. DenseMap paramNameToConcrete; - // If template instantiation list is given, must use that. Otherwise, infer. + if (failed(collectConcreteTemplateParams( + op, rewriter, symTables, callTgt, parentTemplate, unifyResult.value(), + paramNameToConcrete + ))) { + return failure(); + } + + if (paramNameToConcrete.empty()) { + LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] skip: no concrete params\n"); + return failure(); + } + + evaluateTemplateExprs(parentTemplate, paramNameToConcrete); + + InstantiationLayout layout = + buildInstantiationLayout(parentTemplate, op.getTemplateParamsAttr(), paramNameToConcrete); + ModuleOp parentModule = getParentOfType(parentTemplate); + assert(parentModule && "TemplateOp must be nested in a ModuleOp"); + + SymbolRefAttr originalCalleeAttr = op.getCalleeAttr(); + FailureOr newCalleeAttr = + layout.remainingNames.empty() + ? instantiateFully( + op, rewriter, symTables, callTgt, parentTemplate, parentModule, + layout.templateNameWithAttrs, paramNameToConcrete + ) + : instantiatePartially( + op, rewriter, symTables, callTgt, parentTemplate, parentModule, layout, + paramNameToConcrete + ); + if (failed(newCalleeAttr)) { + return failure(); + } + + tracker_.recordInstantiation(originalCalleeAttr); + + // Update the CallOp to point to the instantiated function and mark the module as modified. + rewriter.modifyOpInPlace(op, [&op, &newCalleeAttr, &layout]() { + LLVM_DEBUG({ + llvm::dbgs() << "[InstantiateFuncAtCallOp] updating callee from " << op.getCalleeAttr() + << " to " << *newCalleeAttr << '\n'; + }); + op.setCalleeAttr(*newCalleeAttr); + op.setTemplateParamsAttr(layout.rewrittenCallParams); + }); + tracker_.updateModifiedFlag(true); + return success(); + } + +private: + /// Re-run call/callee type unification so flattening can surface a useful error if a chain of + /// partially-instantiated calls stops unifying once earlier substitutions have been applied. + static FailureOr + unifyTypeSignature(CallOp op, FuncDefOp callTgt, PatternRewriter &rewriter) { + FailureOr unifyResult = op.unifyTypeSignature(callTgt.getFunctionType()); + if (succeeded(unifyResult)) { + return unifyResult; + } + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("target function type does not unify with call type ") + .append(op.getTypeSignature()) + .attachNote(callTgt.getLoc()) + .append("target function declared here"); + }); + } + + /// Populate the concrete subset of template parameters chosen for this instantiation, using + /// explicit call-site arguments when present and otherwise relying on unification. + static LogicalResult collectConcreteTemplateParams( + CallOp op, PatternRewriter &rewriter, SymbolTableCollection &symTables, FuncDefOp callTgt, + TemplateOp parentTemplate, const UnificationMap &unifyResult, + DenseMap ¶mNameToConcrete + ) { auto realParams = parentTemplate.getConstOps(); ArrayAttr callParams = op.getTemplateParamsAttr(); LLVM_DEBUG( llvm::dbgs() << "[InstantiateFuncAtCallOp] TemplateParamsAttr: " << callParams << '\n' ); + + auto recordConcreteParam = [&](FlatSymbolRefAttr paramName, TemplateParamOp paramOp, + Attribute concreteValue) { + if (failed(op.verifyTemplateParamCompatibility(concreteValue, paramOp))) { + return failIncompatibleInferredParam(op, rewriter, paramName, paramOp); + } + paramNameToConcrete[paramName] = concreteValue; + return success(); + }; + + // If there's no template instantiation list, must infer all template parameters. if (isNullOrEmpty(callParams)) { for (auto paramOp : realParams) { auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr()); - auto it = unifyResult->find({paramName, Side::RHS}); - if (it == unifyResult->end()) { + auto inferredValOpt = inferUnifiedParam(unifyResult, paramName); + if (!inferredValOpt.has_value()) { LLVM_DEBUG( llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName << "': not found\n" ); continue; } - Attribute inferredVal = it->second; + Attribute inferredVal = *inferredValOpt; + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] inferredVal: " << inferredVal << '\n' + ); if (!isConcreteAttr(inferredVal)) { LLVM_DEBUG( llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName @@ -1112,238 +1508,194 @@ class InstantiateFuncAtCallOp final : public OpRewritePattern { ); continue; } - // Ensure it's a valid value for the optional type restriction on the TemplateParamOp - if (failed(op.verifyTemplateParamCompatibility(inferredVal, paramOp))) { - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName - << "': incompatible with specified param type. MUST FAIL!\n" - ); - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("inferred value for parameter '") - .append(paramName) - .append("' is incompatible with specified param type") - .attachNote(paramOp.getLoc()) - .append("template parameter declared here"); - }); - } - paramNameToConcrete[paramName] = inferredVal; - } - } else { - // As stated earlier, need to run the verification checks again to ensure the - // instantiation is valid, except for the size check becuase that cannot change. - assert((callParams.size() == llvm::range_size(realParams)) && "per CallOpVerifier"); - if (failed(op.verifyTemplateParamCompatibility(realParams))) { - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("incompatible with specified param type(s)"); - }); - } - if (failed(op.verifyTemplateParamsMatchInferred(realParams, unifyResult.value()))) { - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("incompatible with inferred param value(s)"); - }); - } - // Add the mappings - for (auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) { - auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr()); - if (!isConcreteAttr(attr)) { - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName - << "': not concrete, " << attr << '\n' - ); - continue; + if (failed(recordConcreteParam(paramName, paramOp, inferredVal))) { + return failure(); } - paramNameToConcrete[paramName] = attr; } + return success(); } - if (paramNameToConcrete.empty()) { - LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] skip: no concrete params\n"); - return failure(); + // As stated earlier, need to run the verification checks again to ensure the + // instantiation is valid, except for the size check because that cannot change. + assert((callParams.size() == llvm::range_size(realParams)) && "per CallOpVerifier"); + if (failed(op.verifyTemplateParamCompatibility(realParams))) { + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("incompatible with specified param type(s)"); + }); + } + if (failed(op.verifyTemplateParamsMatchInferred(realParams, unifyResult))) { + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("incompatible with inferred param value(s)"); + }); } - // Evaluate any poly.expr symbols whose param dependencies are now concrete; add them to the - // map so ClonedFuncConstReadOpPattern can replace uses of those symbols too. - evaluateTemplateExprs(parentTemplate, paramNameToConcrete); + // When template parameters are specified on the CallOp, use them as the source of truth + // for concrete arguments, then infer wildcard parameters against the full explicit map. + SmallVector> wildcardParams; + for (auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) { + auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr()); + AttrConcreteness classification = classifyAttrConcreteness(attr); + if (classification == AttrConcreteness::Concrete) { + paramNameToConcrete[paramName] = attr; + continue; + } - // Classify each template parameter as concrete (to be inlined) or remaining (to be preserved). - SmallVector remainingNames; - SmallVector attrsForInstantiatedNameSuffix; - for (Attribute paramName : parentTemplate.getConstNames()) { - auto it = paramNameToConcrete.find(paramName); - if (it != paramNameToConcrete.end()) { - attrsForInstantiatedNameSuffix.push_back(it->second); - } else { - attrsForInstantiatedNameSuffix.push_back(nullptr); // placeholder for non-concrete param - remainingNames.push_back(paramName); + if (classification == AttrConcreteness::NonConcrete) { + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName + << "': not concrete, " << attr << '\n' + ); + continue; } + wildcardParams.emplace_back(paramOp, paramName); } - MLIRContext *ctx = op.getContext(); - ModuleOp parentModule = getParentOfType(parentTemplate); - assert(parentModule && "TemplateOp must be nested in a ModuleOp"); - - // Build the (partially-)instantiated template name, e.g., "TemplateName_8_\x1A" where \x1A - // is a placeholder character at the position of a non-concrete parameter. - std::string templateNameWithAttrs = BuildShortTypeString::from( - parentTemplate.getSymName().str(), attrsForInstantiatedNameSuffix - ); - - // Helper lambda to: - // 1. build the FuncInstTypeConverter and apply it to a cloned function - // 2. verify CallOp in the converted function are valid for their respective targets - // and emit a more helpful error at this point rather than discovering it later - // when verifying the entire module. - auto applyBodyConversions = [&](FuncDefOp newFunc) -> LogicalResult { - FuncInstTypeConverter tyConv(paramNameToConcrete); - ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx); - target.addDynamicallyLegalOp([&tyConv](ConstReadOp p) { - // Legal if it's not in the map of concrete attribute instantiations - return !tyConv.containsParam(p.getConstNameAttr()); - }); - SmallVector delayedDiagnostics; - RewritePatternSet bodyPatterns = newGeneralRewritePatternSet(tyConv, ctx, target); - bodyPatterns.add( - tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics - ); - if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) { - return failure(); + WildcardTypeBodyInferer bodyInferer(symTables, paramNameToConcrete); + for (auto [paramOp, paramName] : wildcardParams) { + auto inferredValOpt = inferUnifiedParam(unifyResult, paramName); + if (inferredValOpt.has_value() && isConcreteAttr(*inferredValOpt)) { + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] inferredVal: " << *inferredValOpt << '\n' + ); + if (failed(recordConcreteParam(paramName, paramOp, *inferredValOpt))) { + return failure(); + } + continue; } - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] instantiated clone: " << newFunc << '\n' - ); - ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics)); - // Verify CallOp match targets - SymbolTableCollection tables; - WalkResult res = newFunc.walk([&tables](CallOp nestedCall) { - return WalkResult(nestedCall.verifySymbolUses(tables)); - }); - return failure(res.wasInterrupted()); - }; - - SmallVector symPieces = getPieces(op.getCalleeAttr()); - assert(symPieces.size() >= 2 && "callee must include at least template and function names"); - SymbolRefAttr originalCalleeAttr = asSymbolRefAttr(symPieces); - if (remainingNames.empty()) { - // FULL INSTANTIATION: place the cloned function directly in the parent module. - // New function name encodes all parameter values, e.g., "TemplateName_8_12_funcName". - std::string newFuncName = - (mlir::Twine(templateNameWithAttrs) + "_" + callTgt.getSymName()).str(); - StringRef actualNewFuncName = newFuncName; - if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) { - FuncDefOp newFunc = callTgt.clone(); - newFunc.setSymName(newFuncName); - convertCalleesInPlace(newFunc, paramNameToConcrete); - // Insert before the TemplateOp; symbol table may adjust the name to ensure uniqueness. - symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate)); - actualNewFuncName = newFunc.getSymName(); + inferredValOpt = bodyInferer.infer(callTgt, paramName); + if (inferredValOpt.has_value() && isConcreteAttr(*inferredValOpt)) { LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] created full instantiation function: " - << actualNewFuncName << '\n' + llvm::dbgs() << "[InstantiateFuncAtCallOp] body-inferred value for param '" + << paramName << "': " << *inferredValOpt << '\n' ); - if (failed(applyBodyConversions(newFunc))) { - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for " - << actualNewFuncName << '\n' - ); - newFunc->erase(); - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("failure while creating instantiated function '", actualNewFuncName, '\''); - }); + if (failed(recordConcreteParam(paramName, paramOp, *inferredValOpt))) { + return failure(); } - } else { + } + } + return success(); + } + + /// Create or reuse a fully-instantiated clone in the parent module and return the rewritten + /// module-level callee reference. + static FailureOr instantiateFully( + CallOp op, PatternRewriter &rewriter, SymbolTableCollection &symTables, FuncDefOp callTgt, + TemplateOp parentTemplate, ModuleOp parentModule, StringRef templateNameWithAttrs, + const DenseMap ¶mNameToConcrete + ) { + MLIRContext *ctx = op.getContext(); + std::string newFuncName = + (mlir::Twine(templateNameWithAttrs) + "_" + callTgt.getSymName()).str(); + StringRef actualNewFuncName = newFuncName; + if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) { + FuncDefOp newFunc = callTgt.clone(); + newFunc.setSymName(newFuncName); + convertCalleesInPlace(newFunc, paramNameToConcrete); + // Insert before the TemplateOp; symbol table may adjust the name to ensure uniqueness. + symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate)); + actualNewFuncName = newFunc.getSymName(); + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] created full instantiation function: " + << actualNewFuncName << '\n' + ); + if (failed(applyBodyConversions(op, newFunc, paramNameToConcrete))) { LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing full instantiation function: " + llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for " << actualNewFuncName << '\n' ); + newFunc->erase(); + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("failure while creating instantiated function '", actualNewFuncName, '\''); + }); } - // Callee: drop template & original function names, add the new module-level function name. - // Original: @[prefix...]::@TemplateName::@funcName - // New: @[prefix...]::@newFuncName - symPieces.pop_back(); // remove original function name - symPieces.pop_back(); // remove template name - symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName))); } else { - // PARTIAL INSTANTIATION: place the cloned function in a new partially-instantiated - // TemplateOp that retains only the non-concrete parameters. - // New template name encodes the concrete values and uses placeholder chars for the rest, - // e.g., "TemplateName_8_\x1A" where \x1A marks the position of a non-concrete param. - TemplateOp newTemplate; - if (Operation *existing = - symTables.getSymbolTable(parentModule).lookup(templateNameWithAttrs)) { - newTemplate = llvm::dyn_cast(existing); - } - if (!newTemplate) { - // Clone the TemplateOp structure without its body and set the new name. - newTemplate = parentTemplate.cloneWithoutRegions(); - newTemplate.setSymName(templateNameWithAttrs); - assert(newTemplate->getNumRegions() > 0 && "region exists"); - newTemplate.getBodyRegion().emplaceBlock(); - - // Clone the preserved (non-concrete) param/expr ops into the new template in order. - Block &newTemplateBody = newTemplate.getBodyRegion().front(); - for (Attribute name : remainingNames) { - FlatSymbolRefAttr nameSym = llvm::cast(name); - Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr()); - assert(paramOp && "symbol must exist"); - newTemplateBody.push_back(paramOp->clone()); - } + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing full instantiation function: " + << actualNewFuncName << '\n' + ); + } - // Clone and partially convert the function (concretize only the concrete params). - FuncDefOp newFunc = callTgt.clone(); - convertCalleesInPlace(newFunc, paramNameToConcrete); + // Callee: drop template & original function names, add the new module-level function name. + // Original: @[prefix...]::@TemplateName::@funcName + // New: @[prefix...]::@newFuncName + SmallVector symPieces = getPieces(op.getCalleeAttr()); + assert(symPieces.size() >= 2 && "callee must include at least template and function names"); + symPieces.pop_back(); // remove original function name + symPieces.pop_back(); // remove template name + symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName))); + return asSymbolRefAttr(symPieces); + } + + /// Create or reuse a partially-instantiated template that preserves the remaining non-concrete + /// parameters and return the rewritten nested callee reference. + /// New template name encodes the concrete values and uses placeholder chars for the rest, + /// e.g., "TemplateName_8_\x1A" where \x1A marks the position of a non-concrete param. + static FailureOr instantiatePartially( + CallOp op, PatternRewriter &rewriter, SymbolTableCollection &symTables, FuncDefOp callTgt, + TemplateOp parentTemplate, ModuleOp parentModule, const InstantiationLayout &layout, + const DenseMap ¶mNameToConcrete + ) { + TemplateOp newTemplate; + if (Operation *existing = + symTables.getSymbolTable(parentModule).lookup(layout.templateNameWithAttrs)) { + newTemplate = llvm::dyn_cast(existing); + } + if (!newTemplate) { + newTemplate = parentTemplate.cloneWithoutRegions(); + newTemplate.setSymName(layout.templateNameWithAttrs); + assert(newTemplate->getNumRegions() > 0 && "region exists"); + newTemplate.getBodyRegion().emplaceBlock(); - // Insert before body conversion so nested concrete callees verify from the root module. Use - // the `SymbolTable::insert()` function so that the name will be made unique if necessary. - symTables.getSymbolTable(newTemplate).insert(newFunc); - symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate)); - if (failed(applyBodyConversions(newFunc))) { - StringRef newFuncName = newFunc.getSymName(); - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for " - << newFuncName << '\n' - ); - newTemplate->erase(); - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag.append("failure while creating instantiated function '", newFuncName, '\''); - }); - } + Block &newTemplateBody = newTemplate.getBodyRegion().front(); + for (Attribute name : layout.remainingNames) { + FlatSymbolRefAttr nameSym = llvm::cast(name); + Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr()); + assert(paramOp && "symbol must exist"); + newTemplateBody.push_back(paramOp->clone()); + } + // Clone and partially convert the function (concretize only the concrete params). + FuncDefOp newFunc = callTgt.clone(); + convertCalleesInPlace(newFunc, paramNameToConcrete); + + // Insert before body conversion so nested concrete callees verify from the root module. Use + // the `SymbolTable::insert()` function so that the name will be made unique if necessary. + symTables.getSymbolTable(newTemplate).insert(newFunc); + symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate)); + if (failed(applyBodyConversions(op, newFunc, paramNameToConcrete))) { + StringRef newFuncName = newFunc.getSymName(); LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] created partial instantiation template: " - << newTemplate.getSymName() << '\n' - ); - } else { - LLVM_DEBUG( - llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing partial instantiation template: " - << newTemplate.getSymName() << '\n' + llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for " << newFuncName + << '\n' ); + newTemplate->erase(); + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("failure while creating instantiated function '", newFuncName, '\''); + }); } - // Callee: replace old template name with new template name, keep the function name. - // Original: @[prefix...]::@TemplateName::@funcName - // New: @[prefix...]::@newTemplateName::@funcName - symPieces.pop_back(); // remove original function name (will be re-appended) - symPieces.pop_back(); // remove original template name - symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.getSymNameAttr())); - symPieces.push_back(FlatSymbolRefAttr::get(callTgt.getSymNameAttr())); - } - tracker_.recordInstantiation(originalCalleeAttr); + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] created partial instantiation template: " + << newTemplate.getSymName() << '\n' + ); + } else { + LLVM_DEBUG( + llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing partial instantiation template: " + << newTemplate.getSymName() << '\n' + ); + } - // Update the CallOp to point to the instantiated function and mark the module as modified. - rewriter.modifyOpInPlace(op, [&op, &symPieces]() { - // Update callee attribute. - SymbolRefAttr newCalleeAttr = asSymbolRefAttr(symPieces); - LLVM_DEBUG({ - llvm::dbgs() << "[InstantiateFuncAtCallOp] updating callee from " << op.getCalleeAttr() - << " to " << newCalleeAttr << '\n'; - }); - op.setCalleeAttr(newCalleeAttr); - // Also drop template param list. If it was present, it was fully used (no partial case). - op.setTemplateParamsAttr(nullptr); - }); - tracker_.updateModifiedFlag(true); - return success(); + // Callee: replace old template name with new template name, keep the function name. + // Original: @[prefix...]::@TemplateName::@funcName + // New: @[prefix...]::@newTemplateName::@funcName + SmallVector symPieces = getPieces(op.getCalleeAttr()); + assert(symPieces.size() >= 2 && "callee must include at least template and function names"); + symPieces.pop_back(); // remove original function name (will be re-appended) + symPieces.pop_back(); // remove original template name + symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.getSymNameAttr())); + symPieces.push_back(FlatSymbolRefAttr::get(callTgt.getSymNameAttr())); + return asSymbolRefAttr(symPieces); } }; @@ -1696,7 +2048,7 @@ class InstantiateAtCallOpCompute final : public OpRewritePattern { Attribute fromCall = std::get<1>(p); // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup // non-parameterized concrete unification for the target struct parameter symbol. - if (!isConcreteAttr<>(fromCall)) { + if (!isConcreteAttr(fromCall)) { Attribute fromTgt = std::get<0>(p); LLVM_DEBUG({ llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n'; diff --git a/lib/Dialect/Polymorphic/Transforms/WildcardArraySpecializationPass.cpp b/lib/Dialect/Polymorphic/Transforms/WildcardArraySpecializationPass.cpp new file mode 100644 index 0000000000..d38bcec1a8 --- /dev/null +++ b/lib/Dialect/Polymorphic/Transforms/WildcardArraySpecializationPass.cpp @@ -0,0 +1,1051 @@ +//===-- WildcardArraySpecializationPass.cpp ---------------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the `-llzk-specialize-wildcard-arrays` pass. +/// +//===----------------------------------------------------------------------===// + +#include "SharedImpl.h" + +#include "llzk/Analysis/SymbolDefTree.h" +#include "llzk/Analysis/SymbolUseGraph.h" +#include "llzk/Dialect/Array/IR/Ops.h" +#include "llzk/Dialect/Cast/IR/Ops.h" +#include "llzk/Dialect/Function/IR/Ops.h" +#include "llzk/Dialect/LLZK/IR/AttributeHelper.h" +#include "llzk/Dialect/Polymorphic/IR/Ops.h" +#include "llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h" +#include "llzk/Dialect/Shared/TypeConversionPatterns.h" +#include "llzk/Dialect/Struct/IR/Ops.h" +#include "llzk/Util/Debug.h" +#include "llzk/Util/SymbolHelper.h" +#include "llzk/Util/SymbolLookup.h" +#include "llzk/Util/SymbolTableLLZK.h" +#include "llzk/Util/TypeHelper.h" + +#include +#include +#include +#include +#include + +// Include the generated base pass class definitions. +namespace llzk::polymorphic { +#define GEN_PASS_DEF_WILDCARDARRAYSPECIALIZATIONPASS +#include "llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h.inc" +} // namespace llzk::polymorphic + +#define DEBUG_TYPE "llzk-specialize-wildcard-arrays" + +using namespace mlir; +using namespace llzk; +using namespace llzk::array; +using namespace llzk::component; +using namespace llzk::function; +using namespace llzk::polymorphic; +using namespace llzk::polymorphic::detail; + +namespace { + +/// Tracks specializations created during the pass so later rewrites can validate +/// that any type change is a legal move toward a more concrete program. +class ConversionTracker { + bool modified = false; + DenseMap> structSpecializations; + DenseMap reverseSpecializations; + DenseSet funcInstantiations; + +public: + bool isModified() const { return modified; } + void resetModifiedFlag() { modified = false; } + void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; } + + void recordInstantiation(SymbolRefAttr funcName) { + funcInstantiations.insert(funcName); + modified = true; + } + + void recordSpecialization(StructType oldType, StructType newType) { + assert(isNullOrEmpty(oldType.getParams()) && "wildcard-array specialization expects plain key"); + SmallVector &specializations = structSpecializations[oldType]; + if (llvm::is_contained(specializations, newType)) { + assert(reverseSpecializations.lookup(newType) == oldType); + return; + } + specializations.push_back(newType); + auto [it, inserted] = reverseSpecializations.try_emplace(newType, oldType); + (void)it; + (void)inserted; + assert(inserted || it->second == oldType); + modified = true; + } + + DenseSet getInstantiatedDefinitionNames() const { + DenseSet instantiatedNames = funcInstantiations; + for (const auto &[origRemoteTy, _] : structSpecializations) { + instantiatedNames.insert(origRemoteTy.getNameRef()); + } + return instantiatedNames; + } + + bool isLegalConversion(Type oldType, Type newType, const char *patName) const { + std::function checkSpecializations = [&](Type oTy, Type nTy) { + if (StructType oldStructType = llvm::dyn_cast(oTy)) { + auto specializationIt = structSpecializations.find(oldStructType); + if (specializationIt != structSpecializations.end() && + llvm::is_contained(specializationIt->second, nTy)) { + return true; + } + } + if (StructType newStructType = llvm::dyn_cast(nTy)) { + if (StructType preImage = reverseSpecializations.lookup(newStructType)) { + if (isMoreConcreteUnification(oTy, preImage, checkSpecializations)) { + return true; + } + } + } + return false; + }; + + if (!isMoreConcreteUnification(oldType, newType, checkSpecializations)) { + LLVM_DEBUG({ + llvm::dbgs() << '[' << patName << "] invalid type conversion from " << oldType << " to " + << newType << '\n'; + }); + return false; + } + return true; + } + + bool areLegalConversions(TypeRange oldTypes, TypeRange newTypes, const char *patName) const { + return oldTypes.size() == newTypes.size() && + llvm::all_of(llvm::zip_equal(oldTypes, newTypes), [&](auto pair) { + return isLegalConversion(std::get<0>(pair), std::get<1>(pair), patName); + }); + } +}; + +/// Turns pattern match failures into hard pass failures with diagnostics. +struct MatchFailureListener : public RewriterBase::Listener { + bool hadFailure = false; + + void notifyMatchFailure(Location loc, function_ref reasonCallback) override { + InFlightDiagnostic diag = emitError(loc); + reasonCallback(*diag.getUnderlyingDiagnostic()); + diag.report(); + hadFailure = true; + } +}; + +static LogicalResult +applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) { + bool currStepModified = false; + MatchFailureListener failureListener; + LogicalResult result = applyPatternsGreedily( + modOp->getRegion(0), std::move(patterns), + GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true}, + &currStepModified + ); + tracker.updateModifiedFlag(currStepModified); + return failure(result.failed() || failureListener.hadFailure); +} + +/// Records every wildcard array type replaced while specializing one call site. +struct WildcardArraySpecializationInfo { + DenseMap replacements; + SmallVector> ordered; + bool hasConflictingReplacements = false; + + bool empty() const { return ordered.empty(); } + + LogicalResult record(ArrayType oldTy, ArrayType newTy) { + ordered.emplace_back(oldTy, newTy); + auto it = replacements.find(oldTy); + if (it == replacements.end()) { + replacements.try_emplace(oldTy, newTy); + return success(); + } + hasConflictingReplacements |= it->second != newTy; + return success(); + } + + SmallVector getConcreteTypeAttrs() const { + SmallVector attrs; + attrs.reserve(ordered.size()); + for (const auto &[_, newTy] : ordered) { + attrs.push_back(TypeAttr::get(newTy)); + } + return attrs; + } +}; + +static void updateFuncSignature(FuncDefOp func, FunctionType newFuncTy) { + FunctionType oldFuncTy = func.getFunctionType(); + if (oldFuncTy == newFuncTy) { + return; + } + + func.setFunctionType(newFuncTy); + Region &body = func.getFunctionBody(); + if (body.empty()) { + return; + } + + Block &entryBlock = body.front(); + assert(entryBlock.getNumArguments() == newFuncTy.getNumInputs() && "function arity changed"); + for (auto [arg, newTy] : llvm::zip_equal(entryBlock.getArguments(), newFuncTy.getInputs())) { + arg.setType(newTy); + } +} + +/// Returns whether `type` contains an `array.type` with at least one dynamic +/// dimension anywhere in the nested type structure. +static bool containsWildcardArrayDims(Type type) { + if (ArrayType arrTy = llvm::dyn_cast(type)) { + if (llvm::any_of(arrTy.getDimensionSizes(), [](Attribute dim) { + if (IntegerAttr intAttr = llvm::dyn_cast(dim)) { + return isDynamic(intAttr); + } + return false; + })) { + return true; + } + return containsWildcardArrayDims(arrTy.getElementType()); + } + if (StructType structTy = llvm::dyn_cast(type)) { + if (ArrayAttr params = structTy.getParams()) { + return llvm::any_of(params.getValue(), [](Attribute attr) { + if (TypeAttr typeAttr = llvm::dyn_cast(attr)) { + return containsWildcardArrayDims(typeAttr.getValue()); + } + return false; + }); + } + } + if (FunctionType funcTy = llvm::dyn_cast(type)) { + return llvm::any_of(funcTy.getInputs(), containsWildcardArrayDims) || + llvm::any_of(funcTy.getResults(), containsWildcardArrayDims); + } + return false; +} + +/// Collects every wildcard array in `oldTy` that becomes concrete in `newTy`. +/// Returns failure if the types do not describe the same overall structure. +static LogicalResult collectWildcardArraySpecializations( + Type oldTy, Type newTy, WildcardArraySpecializationInfo &out, + std::optional ignoredStructType = std::nullopt +) { + if (ignoredStructType.has_value() && oldTy == *ignoredStructType && + llvm::isa(newTy)) { + return success(); + } + if (!typesUnify(oldTy, newTy)) { + return failure(); + } + if (FunctionType oldFuncTy = llvm::dyn_cast(oldTy)) { + FunctionType newFuncTy = llvm::dyn_cast(newTy); + if (!newFuncTy || oldFuncTy.getNumInputs() != newFuncTy.getNumInputs() || + oldFuncTy.getNumResults() != newFuncTy.getNumResults()) { + return failure(); + } + for (auto [oldInput, newInput] : + llvm::zip_equal(oldFuncTy.getInputs(), newFuncTy.getInputs())) { + if (failed(collectWildcardArraySpecializations(oldInput, newInput, out, ignoredStructType))) { + return failure(); + } + } + for (auto [oldResult, newResult] : + llvm::zip_equal(oldFuncTy.getResults(), newFuncTy.getResults())) { + if (failed( + collectWildcardArraySpecializations(oldResult, newResult, out, ignoredStructType) + )) { + return failure(); + } + } + return success(); + } + if (StructType oldStructTy = llvm::dyn_cast(oldTy)) { + StructType newStructTy = llvm::dyn_cast(newTy); + if (!newStructTy) { + return failure(); + } + ArrayAttr oldParams = oldStructTy.getParams(); + ArrayAttr newParams = newStructTy.getParams(); + ArrayRef oldAttrs = oldParams ? oldParams.getValue() : ArrayRef {}; + ArrayRef newAttrs = newParams ? newParams.getValue() : ArrayRef {}; + if (oldAttrs.size() != newAttrs.size()) { + return failure(); + } + for (auto [oldAttr, newAttr] : llvm::zip_equal(oldAttrs, newAttrs)) { + if (TypeAttr oldTypeAttr = llvm::dyn_cast(oldAttr)) { + TypeAttr newTypeAttr = llvm::dyn_cast(newAttr); + if (!newTypeAttr || + failed(collectWildcardArraySpecializations( + oldTypeAttr.getValue(), newTypeAttr.getValue(), out, ignoredStructType + ))) { + return failure(); + } + } + } + return success(); + } + ArrayType oldArrTy = llvm::dyn_cast(oldTy); + ArrayType newArrTy = llvm::dyn_cast(newTy); + if (!oldArrTy || !newArrTy) { + return success(); + } + if (oldArrTy.getDimensionSizes().size() != newArrTy.getDimensionSizes().size() || + failed(collectWildcardArraySpecializations( + oldArrTy.getElementType(), newArrTy.getElementType(), out, ignoredStructType + ))) { + return failure(); + } + + bool changed = false; + for (auto [oldDim, newDim] : + llvm::zip_equal(oldArrTy.getDimensionSizes(), newArrTy.getDimensionSizes())) { + if (auto oldInt = llvm::dyn_cast(oldDim); oldInt && isDynamic(oldInt)) { + if (auto newInt = llvm::dyn_cast(newDim); newInt && !isDynamic(newInt)) { + changed = true; + } + } + } + if (!changed) { + return success(); + } + return out.record(oldArrTy, newArrTy); +} + +static bool functionTypeIsMoreConcrete( + FunctionType oldTy, FunctionType newTy, const ConversionTracker &tracker, const char *patName, + std::optional ignoredStructType = std::nullopt +) { + auto isCompatible = [&](Type oldType, Type newType) { + if (ignoredStructType.has_value() && oldType == *ignoredStructType && + llvm::isa(newType)) { + return true; + } + return tracker.isLegalConversion(oldType, newType, patName); + }; + + return oldTy.getNumInputs() == newTy.getNumInputs() && + oldTy.getNumResults() == newTy.getNumResults() && + llvm::all_of(llvm::zip_equal(oldTy.getInputs(), newTy.getInputs()), [&](auto pair) { + return isCompatible(std::get<0>(pair), std::get<1>(pair)); + }) && llvm::all_of(llvm::zip_equal(oldTy.getResults(), newTy.getResults()), [&](auto pair) { + return isCompatible(std::get<0>(pair), std::get<1>(pair)); + }); +} + +/// Rewrites wildcard arrays, and optionally one enclosing struct type, to the +/// concrete types inferred for a specialization. +class WildcardArrayTypeConverter : public TypeConverter { + const DenseMap &arrayReplacements_; + std::optional oldStructType_; + std::optional newStructType_; + +public: + WildcardArrayTypeConverter( + const DenseMap &arrayReplacements, + std::optional oldStructType = std::nullopt, + std::optional newStructType = std::nullopt + ) + : TypeConverter(), arrayReplacements_(arrayReplacements), oldStructType_(oldStructType), + newStructType_(newStructType) { + addConversion([](Type inputTy) { return inputTy; }); + + addConversion([this](ArrayType inputTy) -> Type { + Type newElemTy = this->convertType(inputTy.getElementType()); + auto it = arrayReplacements_.find(inputTy); + if (it != arrayReplacements_.end()) { + ArrayType replacement = it->second; + if (replacement.getElementType() != newElemTy) { + return replacement.cloneWith(newElemTy); + } + return replacement; + } + if (newElemTy != inputTy.getElementType()) { + return inputTy.cloneWith(newElemTy); + } + return inputTy; + }); + + addConversion([this](StructType inputTy) -> Type { + if (oldStructType_.has_value() && newStructType_.has_value() && inputTy == *oldStructType_) { + return *newStructType_; + } + if (ArrayAttr params = inputTy.getParams()) { + SmallVector updated; + bool changed = false; + for (Attribute attr : params.getValue()) { + if (TypeAttr typeAttr = llvm::dyn_cast(attr)) { + Type newTy = this->convertType(typeAttr.getValue()); + updated.push_back(TypeAttr::get(newTy)); + changed |= newTy != typeAttr.getValue(); + } else { + updated.push_back(attr); + } + } + if (changed) { + return StructType::get( + inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated) + ); + } + } + return inputTy; + }); + } +}; + +/// Produces a more precise cast result type when the input carries concrete +/// array sizes for dimensions that were wildcarded in the result. +static std::optional refineCastResultArrayWildcards(Type resultTy, Type inputTy) { + ArrayType resultArrTy = llvm::dyn_cast(resultTy); + ArrayType inputArrTy = llvm::dyn_cast(inputTy); + if (!resultArrTy || !inputArrTy) { + return std::nullopt; + } + if (resultArrTy.getDimensionSizes().size() != inputArrTy.getDimensionSizes().size() || + !typesUnify(resultArrTy.getElementType(), inputArrTy.getElementType())) { + return std::nullopt; + } + + SmallVector refinedDims; + bool changed = false; + for (auto [resultDim, inputDim] : + llvm::zip_equal(resultArrTy.getDimensionSizes(), inputArrTy.getDimensionSizes())) { + if (auto resultInt = llvm::dyn_cast(resultDim); + resultInt && isDynamic(resultInt)) { + if (auto inputInt = llvm::dyn_cast(inputDim); inputInt && !isDynamic(inputInt)) { + refinedDims.push_back(inputDim); + changed = true; + continue; + } + } + refinedDims.push_back(resultDim); + } + + if (!changed) { + return std::nullopt; + } + return resultArrTy.cloneWith(resultArrTy.getElementType(), refinedDims); +} + +/// Template-scoped references name template parameters rather than concrete +/// symbols, so they cannot be specialized here. +static bool calleeReferencesTemplateParam(CallOp op) { + SymbolRefAttr callee = op.getCalleeAttr(); + if (!callee || callee.getNestedReferences().size() != 1) { + return false; + } + TemplateOp parentTemplate = getParentOfType(op); + if (!parentTemplate) { + return false; + } + return parentTemplate.hasConstNamed(callee.getRootReference()); +} + +namespace Cleanup { + +/// Shared state for post-specialization cleanup helpers. +class CleanupBase { +public: + SymbolTableCollection tables; + + CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph) + : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {} + +protected: + ModuleOp rootMod; + const SymbolDefTree &defTree; + const SymbolUseGraph &useGraph; +}; + +static bool isErasableDefinition(Operation *op) { + if (llvm::isa(op)) { + return true; + } + if (function::FuncDefOp fdef = llvm::dyn_cast(op)) { + return !fdef.isInStruct(); + } + return false; +} + +/// Removes parameterized definitions whose specialized replacements now cover +/// every remaining use. +struct FromEraseSet : public CleanupBase { + FromEraseSet( + ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph, + DenseSet &&tryToErasePaths + ) + : CleanupBase(root, symDefTree, symUseGraph) { + for (SymbolRefAttr path : tryToErasePaths) { + Operation *lookupFrom = rootMod.getOperation(); + auto res = lookupSymbolIn(tables, path, Within(), lookupFrom); + assert(succeeded(res) && "inputs must be valid symbol references"); + assert(isErasableDefinition(res->get()) && "inputs must be cleanup candidates"); + if (!res->viaInclude()) { + tryToErase.insert(llvm::cast(res->get())); + } + } + } + + LogicalResult eraseUnusedDefinitions() { + for (SymbolOpInterface sym : tryToErase) { + collectSafeToErase(sym); + } + for (auto &it : llvm::make_early_inc_range(visitedPlusSafetyResult)) { + if (!it.second || !tryToErase.contains(it.first)) { + visitedPlusSafetyResult.erase(it.first); + } + } + for (auto &[sym, _] : visitedPlusSafetyResult) { + sym.erase(); + } + return success(); + } + + const DenseSet &getTryToEraseSet() const { return tryToErase; } + +private: + DenseSet tryToErase; + DenseMap visitedPlusSafetyResult; + DenseMap lookupCache; + + bool collectSafeToErase(SymbolOpInterface check) { + assert(check); + auto visited = visitedPlusSafetyResult.find(check); + if (visited != visitedPlusSafetyResult.end()) { + return visited->second; + } + if (isErasableDefinition(check.getOperation()) && !tryToErase.contains(check)) { + visitedPlusSafetyResult[check] = false; + return false; + } + visitedPlusSafetyResult[check] = true; + if (collectSafeToErase(defTree.lookupNode(check))) { + const auto *useNode = useGraph.lookupNode(check); + if (!useNode || collectSafeToErase(useNode)) { + return true; + } + } + visitedPlusSafetyResult[check] = false; + return false; + } + + bool collectSafeToErase(const SymbolDefTreeNode *check) { + assert(check); + if (const SymbolDefTreeNode *p = check->getParent()) { + if (SymbolOpInterface checkOp = p->getOp()) { + return collectSafeToErase(checkOp); + } + } + return true; + } + + bool collectSafeToErase(const SymbolUseGraphNode *check) { + assert(check); + for (const SymbolUseGraphNode *p : check->predecessorIter()) { + if (SymbolOpInterface checkOp = cachedLookup(p)) { + if (!collectSafeToErase(checkOp)) { + return false; + } + } + } + return true; + } + + SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) { + assert(node && "must provide a node"); + auto fromCache = lookupCache.find(node); + if (fromCache != lookupCache.end()) { + return fromCache->second; + } + auto lookupRes = node->lookupSymbol(tables); + assert(succeeded(lookupRes) && "graph contains node with invalid path"); + assert(lookupRes->get() != nullptr && "lookup must return an Operation"); + SymbolOpInterface actualRes = + lookupRes->viaInclude() ? nullptr : llvm::cast(lookupRes->get()); + lookupCache[node] = actualRes; + return actualRes; + } +}; + +} // namespace Cleanup + +static LogicalResult erasePreimageOfInstantiations( + ModuleOp rootMod, const ConversionTracker &tracker, const SymbolDefTree &symDefTree, + const SymbolUseGraph &symUseGraph +) { + Cleanup::FromEraseSet cleaner( + rootMod, symDefTree, symUseGraph, tracker.getInstantiatedDefinitionNames() + ); + LogicalResult res = cleaner.eraseUnusedDefinitions(); + if (failed(res)) { + return res; + } + rootMod->walk([&cleaner, &symUseGraph](Operation *walkedOp) { + SymbolOpInterface op = llvm::dyn_cast(walkedOp); + if (!op || !cleaner.getTryToEraseSet().contains(op)) { + return; + } + if (const SymbolUseGraphNode *node = symUseGraph.lookupNode(op); + node && node->hasPredecessor()) { + op.emitWarning("Parameterized definition still has uses!").report(); + } + }); + return success(); +} + +namespace CastRefinement { + +/// Refines `poly.unifiable_cast` result types by replacing wildcard array +/// dimensions with concrete dimensions inferred from the operand type. +class UpdateUnifiableCastResultType final : public OpRewritePattern { + ConversionTracker &tracker_; + +public: + UpdateUnifiableCastResultType(MLIRContext *ctx, ConversionTracker &tracker) + : OpRewritePattern(ctx, 3), tracker_(tracker) {} + + LogicalResult matchAndRewrite(UnifiableCastOp op, PatternRewriter &rewriter) const override { + std::optional refinedResultTy = + refineCastResultArrayWildcards(op.getResult().getType(), op.getInput().getType()); + if (!refinedResultTy.has_value() || *refinedResultTy == op.getResult().getType()) { + return failure(); + } + if (!tracker_.isLegalConversion( + op.getResult().getType(), *refinedResultTy, "UpdateUnifiableCastResultType" + )) { + return failure(); + } + rewriter.modifyOpInPlace(op, [&]() { op.getResult().setType(*refinedResultTy); }); + return success(); + } +}; + +LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) { + MLIRContext *ctx = modOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx, tracker); + return applyAndFoldGreedily(modOp, tracker, std::move(patterns)); +} + +} // namespace CastRefinement + +namespace WildcardFunctionSpecialization { + +static SymbolRefAttr replaceLeafReference(SymbolRefAttr symRef, StringRef newLeafName) { + SmallVector pieces = getPieces(symRef); + assert(!pieces.empty() && "symbol reference must have at least one piece"); + pieces.back() = FlatSymbolRefAttr::get(StringAttr::get(symRef.getContext(), newLeafName)); + return asSymbolRefAttr(pieces); +} + +static std::string +buildWildcardSpecializationName(StringRef baseName, const WildcardArraySpecializationInfo &info) { + return BuildShortTypeString::from(baseName.str(), info.getConcreteTypeAttrs()); +} + +/// Retargets calls nested inside a specialized struct body to the corresponding +/// specialized struct member definition. +class CallStructFuncPattern : public OpConversionPattern { +public: + CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx, /*benefit=*/1) {} + + LogicalResult matchAndRewrite( + CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter + ) const override { + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) { + return op->emitError("Could not convert Op result types."); + } + + SymbolRefAttr calleeAttr = op.getCalleeAttr(); + if (op.calleeIsStructCompute()) { + if (StructType newStTy = getIfSingleton(newResultTypes)) { + calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference()); + } + } else if (op.calleeIsStructConstrain()) { + if (StructType newStTy = getAtIndex(adapter.getArgOperands().getTypes(), 0)) { + calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference()); + } + } + + replaceOpWithNewOp( + rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(), + op.getNumDimsPerMapAttr(), adapter.getArgOperands() + ); + return success(); + } +}; + +/// Updates struct member declarations after the surrounding struct has been +/// specialized to concrete wildcard array types. +class MemberDefOpPattern : public OpConversionPattern { +public: + MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx, /*benefit=*/1) {} + + LogicalResult + matchAndRewrite(MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override { + Type oldMemberType = op.getType(); + Type newMemberType = getTypeConverter()->convertType(oldMemberType); + if (oldMemberType == newMemberType) { + return failure(); + } + rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.setType(newMemberType); }); + return success(); + } +}; + +static LogicalResult verifyNestedCallSymbols(FuncDefOp func) { + SymbolTableCollection tables; + WalkResult result = func.walk([&tables](CallOp nestedCall) { + return WalkResult(nestedCall.verifySymbolUses(tables)); + }); + return failure(result.wasInterrupted()); +} + +static LogicalResult applyWildcardSpecializationConversions( + FuncDefOp newFunc, const WildcardArraySpecializationInfo &info +) { + MLIRContext *ctx = newFunc.getContext(); + WildcardArrayTypeConverter tyConv(info.replacements); + ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx); + RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target); + if (failed(applyFullConversion(newFunc, target, std::move(patterns)))) { + return failure(); + } + return verifyNestedCallSymbols(newFunc); +} + +static LogicalResult applyWildcardSpecializationConversions( + FuncDefOp newFunc, FunctionType newFuncTy, const WildcardArraySpecializationInfo &info +) { + updateFuncSignature(newFunc, newFuncTy); + if (!info.hasConflictingReplacements) { + return applyWildcardSpecializationConversions(newFunc, info); + } + return verifyNestedCallSymbols(newFunc); +} + +static LogicalResult applyWildcardSpecializationConversions( + StructDefOp newStruct, StructType oldStructType, StructType newStructType, + const WildcardArraySpecializationInfo &info +) { + MLIRContext *ctx = newStruct.getContext(); + WildcardArrayTypeConverter tyConv(info.replacements, oldStructType, newStructType); + ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx); + RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target); + patterns.add(tyConv, ctx); + return applyFullConversion(newStruct, target, std::move(patterns)); +} + +static FailureOr getOrCreateSpecializedFreeFunc( + CallOp op, PatternRewriter &rewriter, SymbolTableCollection &symTables, FuncDefOp targetFunc, + const WildcardArraySpecializationInfo &info, FunctionType callSig +) { + ModuleOp parentModule = getParentOfType(targetFunc); + assert(parentModule && "free function must be nested in a module"); + + std::string newFuncName = buildWildcardSpecializationName(targetFunc.getSymName(), info); + FuncDefOp newFunc; + if (Operation *existing = symTables.getSymbolTable(parentModule).lookup(newFuncName)) { + newFunc = llvm::dyn_cast(existing); + if (!newFunc) { + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("specialized function name collision for '", newFuncName, '\''); + }); + } + } else { + newFunc = targetFunc.clone(); + newFunc.setSymName(newFuncName); + symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(targetFunc)); + if (failed(applyWildcardSpecializationConversions(newFunc, callSig, info))) { + newFunc.erase(); + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("failure while creating wildcard-specialized function '", newFuncName, '\''); + }); + } + } + + return replaceLeafReference(op.getCalleeAttr(), newFunc.getSymName()); +} + +static FailureOr getOrCreateSpecializedStruct( + CallOp op, PatternRewriter &rewriter, SymbolTableCollection &symTables, + StructDefOp targetStruct, const WildcardArraySpecializationInfo &info, + ConversionTracker &tracker +) { + ModuleOp parentModule = getParentOfType(targetStruct); + assert(parentModule && "struct definition must be nested in a module"); + + StructType oldStructType = targetStruct.getType(); + std::string newStructName = buildWildcardSpecializationName(targetStruct.getSymName(), info); + StructDefOp newStruct; + if (Operation *existing = symTables.getSymbolTable(parentModule).lookup(newStructName)) { + newStruct = llvm::dyn_cast(existing); + if (!newStruct) { + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("specialized struct name collision for '", newStructName, '\''); + }); + } + } else { + newStruct = targetStruct.clone(); + newStruct.setSymName(newStructName); + symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(targetStruct)); + StructType newStructType = newStruct.getType(); + if (failed( + applyWildcardSpecializationConversions(newStruct, oldStructType, newStructType, info) + )) { + newStruct.erase(); + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag.append("failure while creating wildcard-specialized struct '", newStructName, '\''); + }); + } + } + + StructType newStructType = newStruct.getType(); + tracker.recordSpecialization(oldStructType, newStructType); + return newStructType; +} + +/// Specializes calls whose target signature still contains wildcard arrays but +/// whose call-site signature has become concrete enough to resolve them. +class SpecializeWildcardCallOp final : public OpRewritePattern { + ConversionTracker &tracker_; + +public: + SpecializeWildcardCallOp(MLIRContext *ctx, ConversionTracker &tracker) + : OpRewritePattern(ctx), tracker_(tracker) {} + + LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override { + if (calleeReferencesTemplateParam(op)) { + return failure(); + } + + SymbolTableCollection symTables; + FailureOr> targetRes = op.getCalleeTarget(symTables); + if (failed(targetRes)) { + return failure(); + } + FuncDefOp targetFunc = targetRes->get(); + if (llvm::isa(targetFunc->getParentOp())) { + return failure(); + } + + FunctionType targetSig = targetFunc.getFunctionType(); + FunctionType callSig = op.getTypeSignature(); + StructDefOp targetStruct = + targetFunc.isInStruct() ? getParentOfType(targetFunc) : StructDefOp(); + std::optional ignoredStructType = + targetStruct ? std::optional(targetStruct.getType()) : std::nullopt; + if (!containsWildcardArrayDims(targetSig) || + !functionTypeIsMoreConcrete( + targetSig, callSig, tracker_, "SpecializeWildcardCallOp", ignoredStructType + )) { + return failure(); + } + + WildcardArraySpecializationInfo info; + if (failed(collectWildcardArraySpecializations(targetSig, callSig, info, ignoredStructType)) || + info.empty()) { + return failure(); + } + + if (!targetFunc.isInStruct()) { + FailureOr newCalleeAttr = + getOrCreateSpecializedFreeFunc(op, rewriter, symTables, targetFunc, info, callSig); + if (failed(newCalleeAttr)) { + return failure(); + } + SmallVector newResultTypes; + FailureOr> specializedFuncRes = + lookupTopLevelSymbol(symTables, *newCalleeAttr, op); + if (failed(specializedFuncRes)) { + return failure(); + } + newResultTypes.append( + specializedFuncRes->get().getFunctionType().getResults().begin(), + specializedFuncRes->get().getFunctionType().getResults().end() + ); + if (!tracker_.areLegalConversions( + op.getResultTypes(), newResultTypes, "SpecializeWildcardCallOp" + )) { + return failure(); + } + tracker_.recordInstantiation(op.getCalleeAttr()); + tracker_.updateModifiedFlag(true); + replaceOpWithNewOp( + rewriter, op, TypeRange(newResultTypes), *newCalleeAttr, + CallOp::toVectorOfValueRange(op.getMapOperands()), op.getNumDimsPerMapAttr(), + op.getArgOperands() + ); + return success(); + } + + assert(targetStruct && "struct function must have a parent struct"); + if (llvm::isa(targetStruct->getParentOp())) { + return failure(); + } + + StructType targetStructType = targetStruct.getType(); + SmallVector newResultTypes; + if (targetFunc.nameIsConstrain()) { + StructType selfType = getAtIndex(op.getArgOperands().getTypes(), 0); + if (!selfType) { + return failure(); + } + + std::string newStructName = buildWildcardSpecializationName(targetStruct.getSymName(), info); + SymbolRefAttr expectedSelfNameRef = + replaceLeafReference(targetStructType.getNameRef(), newStructName); + if (selfType.getNameRef() != expectedSelfNameRef) { + return failure(); + } + } + + FailureOr newStructTypeRes = + getOrCreateSpecializedStruct(op, rewriter, symTables, targetStruct, info, tracker_); + if (failed(newStructTypeRes)) { + return failure(); + } + StructType newStructType = *newStructTypeRes; + + if (!targetFunc.nameIsConstrain()) { + WildcardArrayTypeConverter tyConv(info.replacements, targetStructType, newStructType); + if (failed(tyConv.convertTypes(op.getResultTypes(), newResultTypes))) { + return failure(); + } + if (!tracker_.areLegalConversions( + op.getResultTypes(), newResultTypes, "SpecializeWildcardCallOp" + )) { + return failure(); + } + } + + tracker_.updateModifiedFlag(true); + SymbolRefAttr newCalleeAttr = + appendLeaf(newStructType.getNameRef(), op.getCallee().getLeafReference()); + replaceOpWithNewOp( + rewriter, op, TypeRange(newResultTypes), newCalleeAttr, + CallOp::toVectorOfValueRange(op.getMapOperands()), op.getNumDimsPerMapAttr(), + op.getArgOperands() + ); + return success(); + } +}; + +/// Rebinds `constrain` calls onto a previously-created specialized struct when +/// the receiver type already points at that specialized definition. +class RetargetStructConstrainCall final : public OpRewritePattern { + ConversionTracker &tracker_; + +public: + RetargetStructConstrainCall(MLIRContext *ctx, ConversionTracker &tracker) + : OpRewritePattern(ctx), tracker_(tracker) {} + + LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override { + if (!op.calleeIsStructConstrain() || op.getArgOperands().empty()) { + return failure(); + } + + StructType selfType = llvm::dyn_cast(op.getArgOperands().front().getType()); + if (!selfType) { + return failure(); + } + + SymbolTableCollection symTables; + FailureOr> targetRes = op.getCalleeTarget(symTables); + if (failed(targetRes)) { + return failure(); + } + FuncDefOp targetFunc = targetRes->get(); + StructDefOp targetStruct = getParentOfType(targetFunc); + if (!targetStruct || selfType == targetStruct.getType()) { + return failure(); + } + + SymbolRefAttr newCalleeAttr = + appendLeaf(selfType.getNameRef(), op.getCallee().getLeafReference()); + FailureOr> specializedRes = + lookupTopLevelSymbol(symTables, newCalleeAttr, op); + if (failed(specializedRes)) { + return failure(); + } + + tracker_.updateModifiedFlag(true); + replaceOpWithNewOp( + rewriter, op, TypeRange(op.getResultTypes()), newCalleeAttr, + CallOp::toVectorOfValueRange(op.getMapOperands()), op.getNumDimsPerMapAttr(), + op.getArgOperands() + ); + return success(); + } +}; + +LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) { + MLIRContext *ctx = modOp.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx, tracker); + patterns.add(ctx, tracker); + MatchFailureListener failureListener; + walkAndApplyPatterns(modOp, std::move(patterns), &failureListener); + return failure(failureListener.hadFailure); +} + +} // namespace WildcardFunctionSpecialization + +/// Drives wildcard-array cast refinement and callable specialization until the +/// module reaches a fixpoint, then cleans up replaced parameterized symbols. +class PassImpl : public llzk::polymorphic::impl::WildcardArraySpecializationPassBase { +public: + using Base = WildcardArraySpecializationPassBase; + using Base::Base; + +private: + void runOnOperation() override { + ModuleOp modOp = getOperation(); + ConversionTracker tracker; + unsigned loopCount = 0; + do { + ++loopCount; + if (loopCount > iterationLimit) { + llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit + << " iterations!\n"; + signalPassFailure(); + return; + } + tracker.resetModifiedFlag(); + + if (failed(CastRefinement::run(modOp, tracker))) { + llvm::errs() << DEBUG_TYPE << " failed while refining wildcard array cast results\n"; + signalPassFailure(); + return; + } + if (failed(WildcardFunctionSpecialization::run(modOp, tracker))) { + llvm::errs() << DEBUG_TYPE + << " failed while specializing wildcard-array function signatures\n"; + signalPassFailure(); + return; + } + } while (tracker.isModified()); + + if (failed(erasePreimageOfInstantiations( + modOp, tracker, getAnalysis(), getAnalysis() + ))) { + signalPassFailure(); + } + } +}; + +} // namespace diff --git a/test/Transforms/Flattening/instantiate_wildcard.llzk b/test/Transforms/Flattening/instantiate_wildcard.llzk new file mode 100644 index 0000000000..1351fdcc80 --- /dev/null +++ b/test/Transforms/Flattening/instantiate_wildcard.llzk @@ -0,0 +1,252 @@ +// RUN: llzk-opt -split-input-file -llzk-flatten -verify-diagnostics %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + poly.template @FRT { + poly.param @T_return : !poly.tvar<@T_return> + function.def @free_ret() -> !poly.tvar<@T_return> { + %0 = felt.const 0 : <"bn128"> + %1 = poly.unifiable_cast %0 : (!felt.type<"bn128">) -> !poly.tvar<@T_return> + function.return %1 : !poly.tvar<@T_return> + } + } + + poly.template @synth { + poly.param @T_return : !poly.tvar<@T_return> + function.def @synth() { + %0 = function.call @FRT::@free_ret() : () -> !poly.tvar<@T_return> + function.return + } + } + + poly.template @ComponentC { + struct.def @ComponentC { + function.def @compute() -> !struct.type<@ComponentC::@ComponentC<[]>> { + %self = struct.new : !struct.type<@ComponentC::@ComponentC<[]>> + function.call @synth::@synth<[?]>() : () -> () + function.return %self : !struct.type<@ComponentC::@ComponentC<[]>> + } + + function.def @constrain(%self: !struct.type<@ComponentC::@ComponentC<[]>>) { function.return } + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @FRT_f_free_ret() -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_0]] : (!felt.type<"bn128">) -> !felt.type<"bn128"> +// CHECK-NEXT: function.return %[[VAL_1]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @synth_f_synth() { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @FRT_f_free_ret() : () -> !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: module @ComponentC { +// CHECK-NEXT: struct.def @ComponentC { +// CHECK-NEXT: function.def @compute() -> !struct.type<@ComponentC::@ComponentC> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@ComponentC::@ComponentC> +// CHECK-NEXT: function.call @synth_f_synth() : () -> () +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@ComponentC::@ComponentC> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@ComponentC::@ComponentC>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang = "circom", llzk.main = !struct.type<@CallDiffTypeTest::@CallDiffTypeTest<[]>>} { + poly.template @T { + poly.param @T_arg0 : !poly.tvar<@T_arg0> + poly.param @T_return : !poly.tvar<@T_return> + function.def @peel2ranks(%arg0: !poly.tvar<@T_arg0> {function.arg_name = "a"}) -> !poly.tvar<@T_return> attributes {function.allow_non_native_field_ops} { + %felt_const_0 = felt.const 0 : <"bn128"> + %0 = cast.toindex %felt_const_0 : !felt.type<"bn128"> + %felt_const_0_0 = felt.const 0 : <"bn128"> + %1 = cast.toindex %felt_const_0_0 : !felt.type<"bn128"> + %2 = poly.unifiable_cast %arg0 : (!poly.tvar<@T_arg0>) -> !array.type> + %3 = array.read %2[%0, %1] : >, !poly.tvar<@"$e"> + %4 = poly.unifiable_cast %3 : (!poly.tvar<@"$e">) -> !poly.tvar<@T_return> + function.return %4 : !poly.tvar<@T_return> + } + poly.param @"$e" : !poly.tvar<@"$e"> + } + poly.template @CallDiffTypeTest { + struct.def @CallDiffTypeTest { + struct.member @outA : !array.type<5 x !felt.type<"bn128">> {llzk.pub} + struct.member @outB : !felt.type<"bn128"> {llzk.pub} + function.def @compute( + %arg0: !array.type<10,5,5 x !felt.type<"bn128">>, + %arg1: !array.type<10,5 x !felt.type<"bn128">>) + -> !struct.type<@CallDiffTypeTest::@CallDiffTypeTest<[]>> + attributes {function.allow_non_native_field_ops} { + %self = struct.new : <@CallDiffTypeTest::@CallDiffTypeTest<[]>> + %0 = function.call @T::@peel2ranks<[?, ?, ?]>(%arg0) : (!array.type<10,5,5 x !felt.type<"bn128">>) -> !array.type<5 x !felt.type<"bn128">> + struct.writem %self[@outA] = %0 : <@CallDiffTypeTest::@CallDiffTypeTest<[]>>, !array.type<5 x !felt.type<"bn128">> + %1 = function.call @T::@peel2ranks<[?, ?, ?]>(%arg1) : (!array.type<10,5 x !felt.type<"bn128">>) -> !felt.type<"bn128"> + struct.writem %self[@outB] = %1 : <@CallDiffTypeTest::@CallDiffTypeTest<[]>>, !felt.type<"bn128"> + function.return %self : !struct.type<@CallDiffTypeTest::@CallDiffTypeTest<[]>> + } + function.def @constrain( + %arg0: !struct.type<@CallDiffTypeTest::@CallDiffTypeTest<[]>>, + %arg1: !array.type<10,5,5 x !felt.type<"bn128">>, + %arg2: !array.type<10,5 x !felt.type<"bn128">>) + attributes {function.allow_non_native_field_ops} { + %0 = struct.readm %arg0[@outA] : <@CallDiffTypeTest::@CallDiffTypeTest<[]>>, !array.type<5 x !felt.type<"bn128">> + %1 = struct.readm %arg0[@outB] : <@CallDiffTypeTest::@CallDiffTypeTest<[]>>, !felt.type<"bn128"> + %2 = function.call @T::@peel2ranks<[?, ?, ?]>(%arg1) : (!array.type<10,5,5 x !felt.type<"bn128">>) -> !array.type<5 x !felt.type<"bn128">> + constrain.eq %0, %2 : !array.type<5 x !felt.type<"bn128">>, !array.type<5 x !felt.type<"bn128">> + %3 = function.call @T::@peel2ranks<[?, ?, ?]>(%arg2) : (!array.type<10,5 x !felt.type<"bn128">>) -> !felt.type<"bn128"> + constrain.eq %1, %3 : !felt.type<"bn128">, !felt.type<"bn128"> + function.return + } + } + } +} +// CHECK-LABEL: module attributes {llzk.lang = "circom", llzk.main = !struct.type<@CallDiffTypeTest::@CallDiffTypeTest<[]>>} { +// CHECK-NEXT: function.def @"T_!a_!a_!a_peel2ranks"(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<10,5,5 x !felt.type<"bn128">> {function.arg_name = "a"}) -> !array.type<5 x !felt.type<"bn128">> attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_1]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_1]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_0]] : (!array.type<10,5,5 x !felt.type<"bn128">>) -> !array.type> +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = array.extract %[[VAL_4]]{{\[}}%[[VAL_2]], %[[VAL_3]]] : > +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_5]] : (!array.type<5 x !felt.type<"bn128">>) -> !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: function.return %[[VAL_6]] : !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @"T_!a_f_f_peel2ranks"(%[[VAL_7:[0-9a-zA-Z_\.]+]]: !array.type<10,5 x !felt.type<"bn128">> {function.arg_name = "a"}) -> !felt.type<"bn128"> attributes {function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = cast.toindex %[[VAL_8]] : !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_7]] : (!array.type<10,5 x !felt.type<"bn128">>) -> !array.type> +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_11]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : >, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_12]] : (!felt.type<"bn128">) -> !felt.type<"bn128"> +// CHECK-NEXT: function.return %[[VAL_13]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: module @CallDiffTypeTest { +// CHECK-NEXT: struct.def @CallDiffTypeTest { +// CHECK-NEXT: struct.member @outA : !array.type<5 x !felt.type<"bn128">> {llzk.pub} +// CHECK-NEXT: struct.member @outB : !felt.type<"bn128"> {llzk.pub} +// CHECK-NEXT: function.def @compute(%[[VAL_14:[0-9a-zA-Z_\.]+]]: !array.type<10,5,5 x !felt.type<"bn128">>, %[[VAL_15:[0-9a-zA-Z_\.]+]]: !array.type<10,5 x !felt.type<"bn128">>) -> !struct.type<@CallDiffTypeTest::@CallDiffTypeTest> attributes {function.allow_non_native_field_ops, function.allow_witness} { +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = struct.new : <@CallDiffTypeTest::@CallDiffTypeTest> +// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = function.call @"T_!a_!a_!a_peel2ranks"(%[[VAL_14]]) : (!array.type<10,5,5 x !felt.type<"bn128">>) -> !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: struct.writem %[[VAL_16]][@outA] = %[[VAL_17]] : <@CallDiffTypeTest::@CallDiffTypeTest>, !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = function.call @"T_!a_f_f_peel2ranks"(%[[VAL_15]]) : (!array.type<10,5 x !felt.type<"bn128">>) -> !felt.type<"bn128"> +// CHECK-NEXT: struct.writem %[[VAL_16]][@outB] = %[[VAL_18]] : <@CallDiffTypeTest::@CallDiffTypeTest>, !felt.type<"bn128"> +// CHECK-NEXT: function.return %[[VAL_16]] : !struct.type<@CallDiffTypeTest::@CallDiffTypeTest> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_19:[0-9a-zA-Z_\.]+]]: !struct.type<@CallDiffTypeTest::@CallDiffTypeTest>, %[[VAL_20:[0-9a-zA-Z_\.]+]]: !array.type<10,5,5 x !felt.type<"bn128">>, %[[VAL_21:[0-9a-zA-Z_\.]+]]: !array.type<10,5 x !felt.type<"bn128">>) attributes {function.allow_constraint, function.allow_non_native_field_ops} { +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@outA] : <@CallDiffTypeTest::@CallDiffTypeTest>, !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@outB] : <@CallDiffTypeTest::@CallDiffTypeTest>, !felt.type<"bn128"> +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = function.call @"T_!a_!a_!a_peel2ranks"(%[[VAL_20]]) : (!array.type<10,5,5 x !felt.type<"bn128">>) -> !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: constrain.eq %[[VAL_22]], %[[VAL_24]] : !array.type<5 x !felt.type<"bn128">>, !array.type<5 x !felt.type<"bn128">> +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = function.call @"T_!a_f_f_peel2ranks"(%[[VAL_21]]) : (!array.type<10,5 x !felt.type<"bn128">>) -> !felt.type<"bn128"> +// CHECK-NEXT: constrain.eq %[[VAL_23]], %[[VAL_25]] : !felt.type<"bn128">, !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + poly.template @LaterConcrete { + poly.param @A : !poly.tvar<@A> + poly.param @B : !poly.tvar<@B> + + function.def @bridge(%arg0: !poly.tvar<@B>) { + %0 = poly.unifiable_cast %arg0 : (!poly.tvar<@B>) -> !poly.tvar<@A> + function.return + } + } + + poly.template @UseLaterConcrete { + struct.def @UseLaterConcrete { + function.def @compute(%arg0: !felt.type<"bn128">) -> !struct.type<@UseLaterConcrete::@UseLaterConcrete<[]>> { + %self = struct.new : !struct.type<@UseLaterConcrete::@UseLaterConcrete<[]>> + function.call @LaterConcrete::@bridge<[?, !felt.type<"bn128">]>(%arg0) : (!felt.type<"bn128">) -> () + function.return %self : !struct.type<@UseLaterConcrete::@UseLaterConcrete<[]>> + } + + function.def @constrain( + %self: !struct.type<@UseLaterConcrete::@UseLaterConcrete<[]>>, + %arg0: !felt.type<"bn128">) { + function.call @LaterConcrete::@bridge<[?, !felt.type<"bn128">]>(%arg0) : (!felt.type<"bn128">) -> () + function.return + } + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @LaterConcrete_f_f_bridge(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: module @UseLaterConcrete { +// CHECK-NEXT: struct.def @UseLaterConcrete { +// CHECK-NEXT: function.def @compute(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) -> !struct.type<@UseLaterConcrete::@UseLaterConcrete> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@UseLaterConcrete::@UseLaterConcrete> +// CHECK-NEXT: function.call @LaterConcrete_f_f_bridge(%[[VAL_1]]) : (!felt.type<"bn128">) -> () +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@UseLaterConcrete::@UseLaterConcrete> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@UseLaterConcrete::@UseLaterConcrete>, %[[VAL_4:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint} { +// CHECK-NEXT: function.call @LaterConcrete_f_f_bridge(%[[VAL_4]]) : (!felt.type<"bn128">) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + struct.def @CallOuterExplicitReturn { + function.def @compute(%arg0: !felt.type<"bn128">) -> !struct.type<@CallOuterExplicitReturn> { + %self = struct.new : !struct.type<@CallOuterExplicitReturn> + function.call @OuterExplicitReturn::@wrap<[?]>(%arg0) : (!felt.type<"bn128">) -> () + function.return %self : !struct.type<@CallOuterExplicitReturn> + } + + function.def @constrain(%self: !struct.type<@CallOuterExplicitReturn>, %arg0: !felt.type<"bn128">) { + function.call @OuterExplicitReturn::@wrap<[?]>(%arg0) : (!felt.type<"bn128">) -> () + function.return + } + } + + poly.template @InnerExplicitReturn { + poly.param @Ret : !poly.tvar<@Ret> + poly.param @Arg : !poly.tvar<@Arg> + + function.def @ret(%arg0: !poly.tvar<@Arg>) -> !poly.tvar<@Ret> { + %0 = poly.unifiable_cast %arg0 : (!poly.tvar<@Arg>) -> !poly.tvar<@Ret> + function.return %0 : !poly.tvar<@Ret> + } + } + + poly.template @OuterExplicitReturn { + poly.param @T : !poly.tvar<@T> + + function.def @wrap(%arg0: !felt.type<"bn128">) { + %0 = function.call @InnerExplicitReturn::@ret<[!felt.type<"bn128">, ?]>(%arg0) : (!felt.type<"bn128">) -> !poly.tvar<@T> + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @CallOuterExplicitReturn { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) -> !struct.type<@CallOuterExplicitReturn> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@CallOuterExplicitReturn> +// CHECK-NEXT: function.call @OuterExplicitReturn_f_wrap(%[[VAL_0]]) : (!felt.type<"bn128">) -> () +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@CallOuterExplicitReturn> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@CallOuterExplicitReturn>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) attributes {function.allow_constraint} { +// CHECK-NEXT: function.call @OuterExplicitReturn_f_wrap(%[[VAL_3]]) : (!felt.type<"bn128">) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: function.def @InnerExplicitReturn_f_f_ret(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_4]] : (!felt.type<"bn128">) -> !felt.type<"bn128"> +// CHECK-NEXT: function.return %[[VAL_5]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @OuterExplicitReturn_f_wrap(%[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128">) { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = function.call @InnerExplicitReturn_f_f_ret(%[[VAL_6]]) : (!felt.type<"bn128">) -> !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/WildcardArraySpecialization/specialize_wildcard.llzk b/test/Transforms/WildcardArraySpecialization/specialize_wildcard.llzk new file mode 100644 index 0000000000..07ec565ce2 --- /dev/null +++ b/test/Transforms/WildcardArraySpecialization/specialize_wildcard.llzk @@ -0,0 +1,331 @@ +// RUN: llzk-opt -split-input-file -llzk-flatten -llzk-specialize-wildcard-arrays -verify-diagnostics %s | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + poly.template @FRT { + poly.param @T_return : !poly.tvar<@T_return> + function.def @free_ret() -> !poly.tvar<@T_return> { + %0 = felt.const 0 : <"bn128"> + %1 = poly.unifiable_cast %0 : (!felt.type<"bn128">) -> !poly.tvar<@T_return> + function.return %1 : !poly.tvar<@T_return> + } + } + + poly.template @synth { + poly.param @T_return : !poly.tvar<@T_return> + function.def @synth() { + %0 = function.call @FRT::@free_ret() : () -> !poly.tvar<@T_return> + function.return + } + } + + poly.template @ComponentC { + struct.def @ComponentC { + function.def @compute() -> !struct.type<@ComponentC::@ComponentC<[]>> { + %self = struct.new : !struct.type<@ComponentC::@ComponentC<[]>> + function.call @synth::@synth<[?]>() : () -> () + function.return %self : !struct.type<@ComponentC::@ComponentC<[]>> + } + + function.def @constrain(%self: !struct.type<@ComponentC::@ComponentC<[]>>) { function.return } + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @FRT_f_free_ret() -> !felt.type<"bn128"> { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 0 : <"bn128"> +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_0]] : (!felt.type<"bn128">) -> !felt.type<"bn128"> +// CHECK-NEXT: function.return %[[VAL_1]] : !felt.type<"bn128"> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @synth_f_synth() { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @FRT_f_free_ret() : () -> !felt.type<"bn128"> +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: module @ComponentC { +// CHECK-NEXT: struct.def @ComponentC { +// CHECK-NEXT: function.def @compute() -> !struct.type<@ComponentC::@ComponentC> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@ComponentC::@ComponentC> +// CHECK-NEXT: function.call @synth_f_synth() : () -> () +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@ComponentC::@ComponentC> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@ComponentC::@ComponentC>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + struct.def @Log { + function.def @compute(%arg0: !string.type, %arg1: !array.type) -> !struct.type<@Log> { + %self = struct.new : !struct.type<@Log> + function.return %self : !struct.type<@Log> + } + function.def @constrain(%arg0: !struct.type<@Log>, %arg1: !string.type, %arg2: !array.type) { + function.return + } + } + + function.def private @Log$$extern(!string.type, !array.type) -> !felt.type attributes {extern} + + function.def @call_log_extern(%a: !array.type<1 x !felt.type>) -> !felt.type { + %0 = poly.unifiable_cast %a : (!array.type<1 x !felt.type>) -> !array.type + %1 = string.new "extern: " + %2 = function.call @Log$$extern(%1, %0) : (!string.type, !array.type) -> !felt.type + function.return %2 : !felt.type + } + + struct.def @ComponentD { + function.def @compute(%a: !array.type<1 x !felt.type>) -> !struct.type<@ComponentD> { + %self = struct.new : !struct.type<@ComponentD> + %1 = poly.unifiable_cast %a : (!array.type<1 x !felt.type>) -> !array.type + %2 = string.new "test: " + %3 = function.call @Log::@compute(%2, %1) : (!string.type, !array.type) -> !struct.type<@Log> + function.return %self : !struct.type<@ComponentD> + } + + function.def @constrain(%self: !struct.type<@ComponentD>, %a: !array.type<1 x !felt.type>) { function.return } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @"Log_!a" { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !struct.type<@"Log_!a"> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@"Log_!a"> +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@"Log_!a"> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@"Log_!a">, %[[VAL_4:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: function.def private @"Log$$extern_!a"(!string.type, !array.type<1 x !felt.type>) -> !felt.type attributes {extern} +// CHECK-NEXT: function.def @call_log_extern(%[[VAL_6:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = string.new "extern: " +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_6]] : (!array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = function.call @"Log$$extern_!a"(%[[VAL_7]], %[[VAL_8]]) : (!string.type, !array.type<1 x !felt.type>) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_9]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @ComponentD { +// CHECK-NEXT: function.def @compute(%[[VAL_10:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !struct.type<@ComponentD> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = string.new "test: " +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = struct.new : <@ComponentD> +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_10]] : (!array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = function.call @"Log_!a"::@compute(%[[VAL_11]], %[[VAL_13]]) : (!string.type, !array.type<1 x !felt.type>) -> !struct.type<@"Log_!a"> +// CHECK-NEXT: function.return %[[VAL_12]] : !struct.type<@ComponentD> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_15:[0-9a-zA-Z_\.]+]]: !struct.type<@ComponentD>, %[[VAL_16:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + // expected-warning@+1 {{Parameterized definition still has uses!}} + function.def @head(%arr: !array.type) -> !felt.type { + %cast = poly.unifiable_cast %arr : (!array.type) -> !array.type + %zero = felt.const 0 + function.return %zero : !felt.type + } + + function.def private @head$$extern(%msg: !string.type, %arr: !array.type) -> !felt.type attributes {extern} + + function.def @call_one_a(%arr: !array.type<1 x !felt.type>) -> !felt.type { + %0 = function.call @head(%arr) : (!array.type<1 x !felt.type>) -> !felt.type + function.return %0 : !felt.type + } + + function.def @call_one_b(%arr: !array.type<1 x !felt.type>) -> !felt.type { + %0 = function.call @head(%arr) : (!array.type<1 x !felt.type>) -> !felt.type + function.return %0 : !felt.type + } + + function.def @call_two(%arr: !array.type<2 x !felt.type>) -> !felt.type { + %0 = function.call @head(%arr) : (!array.type<2 x !felt.type>) -> !felt.type + function.return %0 : !felt.type + } + + function.def @call_unknown(%arr: !array.type) -> !felt.type { + %0 = function.call @head(%arr) : (!array.type) -> !felt.type + function.return %0 : !felt.type + } + + function.def @call_one_extern(%arr: !array.type<1 x !felt.type>) -> !felt.type { + %msg = string.new "extern: " + %0 = function.call @head$$extern(%msg, %arr) : (!string.type, !array.type<1 x !felt.type>) -> !felt.type + function.return %0 : !felt.type + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @"head_!a"(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = felt.const 0 +// CHECK-NEXT: function.return %[[VAL_1]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @"head_!a"(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 0 +// CHECK-NEXT: function.return %[[VAL_3]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @head(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !array.type) -> !felt.type { +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = felt.const 0 +// CHECK-NEXT: function.return %[[VAL_5]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def private @"head$$extern_!a"(!string.type, !array.type<1 x !felt.type>) -> !felt.type attributes {extern} +// CHECK-NEXT: function.def @call_one_a(%[[VAL_6:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = function.call @"head_!a"(%[[VAL_6]]) : (!array.type<1 x !felt.type>) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_7]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_one_b(%[[VAL_8:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = function.call @"head_!a"(%[[VAL_8]]) : (!array.type<1 x !felt.type>) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_9]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_two(%[[VAL_10:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = function.call @"head_!a"(%[[VAL_10]]) : (!array.type<2 x !felt.type>) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_11]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_unknown(%[[VAL_12:[0-9a-zA-Z_\.]+]]: !array.type) -> !felt.type { +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = function.call @head(%[[VAL_12]]) : (!array.type) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_13]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_one_extern(%[[VAL_14:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !felt.type { +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = string.new "extern: " +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = function.call @"head$$extern_!a"(%[[VAL_15]], %[[VAL_14]]) : (!string.type, !array.type<1 x !felt.type>) -> !felt.type +// CHECK-NEXT: function.return %[[VAL_16]] : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @takes_two( + %arg0: !array.type, + %arg1: !array.type) { + function.return + } + + function.def @call_two_sizes( + %arg0: !array.type<1 x !felt.type>, + %arg1: !array.type<2 x !felt.type>) { + function.call @takes_two(%arg0, %arg1) : (!array.type<1 x !felt.type>, !array.type<2 x !felt.type>) -> () + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @"takes_two_!a_!a"(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_two_sizes(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: !array.type<2 x !felt.type>) { +// CHECK-NEXT: function.call @"takes_two_!a_!a"(%[[VAL_2]], %[[VAL_3]]) : (!array.type<1 x !felt.type>, !array.type<2 x !felt.type>) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + function.def @id(%arr: !array.type) -> !array.type { + function.return %arr : !array.type + } + + function.def @call_id(%arr: !array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> { + %0 = function.call @id(%arr) : (!array.type<1 x !felt.type>) -> !array.type + %1 = poly.unifiable_cast %0 : (!array.type) -> !array.type<1 x !felt.type> + function.return %1 : !array.type<1 x !felt.type> + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @"id_!a"(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> { +// CHECK-NEXT: function.return %[[VAL_0]] : !array.type<1 x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @call_id(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @"id_!a"(%[[VAL_1]]) : (!array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_2]] : (!array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> +// CHECK-NEXT: function.return %[[VAL_3]] : !array.type<1 x !felt.type> +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + poly.template @TLog { + struct.def @Log { + function.def @compute(%msg: !string.type, %arr: !array.type) -> !struct.type<@TLog::@Log<[]>> { + %self = struct.new : !struct.type<@TLog::@Log<[]>> + function.return %self : !struct.type<@TLog::@Log<[]>> + } + + function.def @constrain(%self: !struct.type<@TLog::@Log<[]>>, %msg: !string.type, %arr: !array.type) { + function.return + } + } + } + + struct.def @UseTLog { + function.def @compute(%a: !array.type<1 x !felt.type>) -> !struct.type<@UseTLog> { + %self = struct.new : !struct.type<@UseTLog> + %arr = poly.unifiable_cast %a : (!array.type<1 x !felt.type>) -> !array.type + %msg = string.new "templated: " + %log = function.call @TLog::@Log::@compute(%msg, %arr) : (!string.type, !array.type) -> !struct.type<@TLog::@Log<[]>> + function.return %self : !struct.type<@UseTLog> + } + + function.def @constrain(%self: !struct.type<@UseTLog>, %a: !array.type<1 x !felt.type>) { + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: module @TLog { +// CHECK-NEXT: struct.def @"Log_!a" { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !struct.type<@TLog::@"Log_!a"> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@TLog::@"Log_!a"> +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@TLog::@"Log_!a"> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@TLog::@"Log_!a">, %[[VAL_4:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @UseTLog { +// CHECK-NEXT: function.def @compute(%[[VAL_6:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) -> !struct.type<@UseTLog> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = string.new "templated: " +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = struct.new : <@UseTLog> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = poly.unifiable_cast %[[VAL_6]] : (!array.type<1 x !felt.type>) -> !array.type<1 x !felt.type> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = function.call @TLog::@"Log_!a"::@compute(%[[VAL_7]], %[[VAL_9]]) : (!string.type, !array.type<1 x !felt.type>) -> !struct.type<@TLog::@"Log_!a"> +// CHECK-NEXT: function.return %[[VAL_8]] : !struct.type<@UseTLog> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_11:[0-9a-zA-Z_\.]+]]: !struct.type<@UseTLog>, %[[VAL_12:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// ----- + +module attributes {llzk.lang} { + struct.def @Log { + function.def @compute(%msg: !string.type, %arr: !array.type) -> !struct.type<@Log> { + %self = struct.new : !struct.type<@Log> + function.return %self : !struct.type<@Log> + } + + function.def @constrain(%self: !struct.type<@Log>, %msg: !string.type, %arr: !array.type) { + function.return + } + } + + function.def @relay(%self: !struct.type<@Log>, %msg: !string.type, %arr: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { + function.call @Log::@constrain(%self, %msg, %arr) : (!struct.type<@Log>, !string.type, !array.type<1 x !felt.type>) -> () + function.return + } +} +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NOT: struct.def @"Log_!a" +// CHECK-NEXT: struct.def @Log { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !array.type) -> !struct.type<@Log> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Log> +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@Log> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@Log>, %[[VAL_4:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !array.type) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: function.def @relay(%[[VAL_6:[0-9a-zA-Z_\.]+]]: !struct.type<@Log>, %[[VAL_7:[0-9a-zA-Z_\.]+]]: !string.type, %[[VAL_8:[0-9a-zA-Z_\.]+]]: !array.type<1 x !felt.type>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.call @Log::@constrain(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) : (!struct.type<@Log>, !string.type, !array.type<1 x !felt.type>) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } From b1b8d52ca4e6114cdd9a80417f96342a9f9e8b6c Mon Sep 17 00:00:00 2001 From: project-llzk-release-bot Date: Thu, 18 Jun 2026 15:36:12 -0500 Subject: [PATCH 06/12] Release v2.1.2 (#563) * Setup pre-release files for release v2.1.2 * bump version number * cleanup changelog * Update pre-release counter for RC v2.1.2-rc1 * Finalize release v2.1.2 --------- Co-authored-by: tim-hoffman Co-authored-by: Tim Hoffman --- CHANGELOG.md | 28 +++++++++++++++++++ .../unreleased/array-new-cse-allocation.yaml | 4 --- ...-flatten-does-not-support-cleanup-for.yaml | 2 -- .../fix__poly-typed-aux-members.yaml | 4 --- .../iangneal__execution-engine-updates.yaml | 4 --- .../iangneal__implement-witgen-ops.yaml | 2 -- ...neal__verif-precondition-restrictions.yaml | 5 ---- .../lower-composite-poly-roots.yaml | 2 -- .../unreleased/poly-dynamic-aux-scope.yaml | 2 -- .../unreleased/th__agent_instructions.yaml | 2 -- .../th__allow_non_native_in_contract.yaml | 2 -- .../unreleased/th__cleanup_pass_ods.yaml | 0 .../unreleased/th__fix_313_and_465.yaml | 2 -- .../unreleased/th__instantiate_wildcard.yaml | 5 ---- .../unreleased/th__pod_op_interfaces.yaml | 3 -- .../unreleased/th__prevent_new_pod_cse.yaml | 2 -- .../unreleased/th__reusable_template.yaml | 0 .../th__update-agent-instructions.yaml | 0 .../unreleased/th__use_rdv_wrapper.yaml | 2 -- nix/llzk.nix | 2 +- 20 files changed, 29 insertions(+), 44 deletions(-) delete mode 100644 changelogs/unreleased/array-new-cse-allocation.yaml delete mode 100644 changelogs/unreleased/codex__github-mention-llzk-flatten-does-not-support-cleanup-for.yaml delete mode 100644 changelogs/unreleased/fix__poly-typed-aux-members.yaml delete mode 100644 changelogs/unreleased/iangneal__execution-engine-updates.yaml delete mode 100644 changelogs/unreleased/iangneal__implement-witgen-ops.yaml delete mode 100644 changelogs/unreleased/iangneal__verif-precondition-restrictions.yaml delete mode 100644 changelogs/unreleased/lower-composite-poly-roots.yaml delete mode 100644 changelogs/unreleased/poly-dynamic-aux-scope.yaml delete mode 100644 changelogs/unreleased/th__agent_instructions.yaml delete mode 100644 changelogs/unreleased/th__allow_non_native_in_contract.yaml delete mode 100644 changelogs/unreleased/th__cleanup_pass_ods.yaml delete mode 100644 changelogs/unreleased/th__fix_313_and_465.yaml delete mode 100644 changelogs/unreleased/th__instantiate_wildcard.yaml delete mode 100644 changelogs/unreleased/th__pod_op_interfaces.yaml delete mode 100644 changelogs/unreleased/th__prevent_new_pod_cse.yaml delete mode 100644 changelogs/unreleased/th__reusable_template.yaml delete mode 100644 changelogs/unreleased/th__update-agent-instructions.yaml delete mode 100644 changelogs/unreleased/th__use_rdv_wrapper.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index 57873f5516..0e6c9550a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,31 @@ +## v2.1.2 - 2026-06-18 +### Added +- `--llzk-remove-unused-discardable-allocations` for removing unread discardable allocations and their dead stores. +- `Destructurable*Interface` and `Promotable*Interface` to pod ops/types +- `PodRefOpInterface` and `PodAccessOpInterface` for ops that reference or access pods +- `llzk-pod-to-scalar` pass to destructure pod type values into scalar SSA values +- `llzk-specialize-wildcard-arrays` pass to refine array types with wildcard dimensions +- `--check-output` option for `llzk-witgen` +- AI coding agent instructions file + +### Changed +- Allow non-native field ops in `verif.contract` +- Apply `llzk-flatten` cleanup option to free functions, not just structs +- `verif.contract` now rejects preconditions (`verif.require_*`) derived from struct members or function return values +- `verif.contract` for the `llzk.main` struct now rejects direct `verif.require_*` ops as well + +### Fixed +- Added `scf.if` handling for `llzk-witgen --backend=execution-engine` +- Emit all R1CS-lowering auxiliary members with the exact type of the materialized expression +- Emit polynomial-lowering auxiliary members with the exact type of the materialized expression +- Emit synthesized zero R1CS linear-combination constants with a printable integer width +- Handle wildcard `CallOp` template parameters in the flattening pass. +- Implement missing felt ops for `llzk-witgen` +- Lower high-degree `felt.add`, `felt.sub`, and `felt.neg` equality roots and nonlinear struct constrain call arguments in the poly-lowering pass +- Prevent `--cse` from merging distinct `array.new` allocations or `pod.new` allocations. +- Reject polynomial and R1CS lowering on non-straight-line constrain bodies instead of materializing component-scope auxiliaries inside control flow +- Update SourceRefAnalysis to handle `scf.while`, `verif.contract`, and `verif.include` +- Array and pod scalarization passes now use the RDV pass wrapper with bug fix ## v2.1.1 - 2026-06-04 ### Fixed - Fixed prime field definitions diff --git a/changelogs/unreleased/array-new-cse-allocation.yaml b/changelogs/unreleased/array-new-cse-allocation.yaml deleted file mode 100644 index 9152909146..0000000000 --- a/changelogs/unreleased/array-new-cse-allocation.yaml +++ /dev/null @@ -1,4 +0,0 @@ -fixed: - - Prevent `--cse` from merging distinct `array.new` allocations. -added: - - Add `--llzk-remove-unused-discardable-allocations` for removing unread discardable allocations and their dead stores. diff --git a/changelogs/unreleased/codex__github-mention-llzk-flatten-does-not-support-cleanup-for.yaml b/changelogs/unreleased/codex__github-mention-llzk-flatten-does-not-support-cleanup-for.yaml deleted file mode 100644 index 5895179a66..0000000000 --- a/changelogs/unreleased/codex__github-mention-llzk-flatten-does-not-support-cleanup-for.yaml +++ /dev/null @@ -1,2 +0,0 @@ -changed: - - Apply `llzk-flatten` cleanup option to free functions, not just structs diff --git a/changelogs/unreleased/fix__poly-typed-aux-members.yaml b/changelogs/unreleased/fix__poly-typed-aux-members.yaml deleted file mode 100644 index d4561d7685..0000000000 --- a/changelogs/unreleased/fix__poly-typed-aux-members.yaml +++ /dev/null @@ -1,4 +0,0 @@ -fixed: - - Emit polynomial-lowering auxiliary members with the exact type of the materialized expression - - Emit all R1CS-lowering auxiliary members with the exact type of the materialized expression - - Emit synthesized zero R1CS linear-combination constants with a printable integer width diff --git a/changelogs/unreleased/iangneal__execution-engine-updates.yaml b/changelogs/unreleased/iangneal__execution-engine-updates.yaml deleted file mode 100644 index dc472fa347..0000000000 --- a/changelogs/unreleased/iangneal__execution-engine-updates.yaml +++ /dev/null @@ -1,4 +0,0 @@ -added: - - '`--check-output` option for `llzk-witgen`' -fixed: - - Added `scf.if` handling for `llzk-witgen --backend=execution-engine` diff --git a/changelogs/unreleased/iangneal__implement-witgen-ops.yaml b/changelogs/unreleased/iangneal__implement-witgen-ops.yaml deleted file mode 100644 index 1cbf274485..0000000000 --- a/changelogs/unreleased/iangneal__implement-witgen-ops.yaml +++ /dev/null @@ -1,2 +0,0 @@ -fixed: - - Implement missing felt ops for `llzk-witgen` diff --git a/changelogs/unreleased/iangneal__verif-precondition-restrictions.yaml b/changelogs/unreleased/iangneal__verif-precondition-restrictions.yaml deleted file mode 100644 index b177bbc0fb..0000000000 --- a/changelogs/unreleased/iangneal__verif-precondition-restrictions.yaml +++ /dev/null @@ -1,5 +0,0 @@ -changed: - - '`verif.contract` now rejects preconditions (`verif.require_*`) derived from struct members or function return values' - - '`verif.contract` for the `llzk.main` struct rejects direct `verif.require_*` ops as well' -fixed: - - Update SourceRefAnalysis to handle `scf.while`, `verif.contract`, and `verif.include` diff --git a/changelogs/unreleased/lower-composite-poly-roots.yaml b/changelogs/unreleased/lower-composite-poly-roots.yaml deleted file mode 100644 index 60d8e480c7..0000000000 --- a/changelogs/unreleased/lower-composite-poly-roots.yaml +++ /dev/null @@ -1,2 +0,0 @@ -fixed: - - Lower high-degree felt.add, felt.sub, and felt.neg equality roots and nonlinear struct constrain call arguments in the poly-lowering pass diff --git a/changelogs/unreleased/poly-dynamic-aux-scope.yaml b/changelogs/unreleased/poly-dynamic-aux-scope.yaml deleted file mode 100644 index 23c36cb813..0000000000 --- a/changelogs/unreleased/poly-dynamic-aux-scope.yaml +++ /dev/null @@ -1,2 +0,0 @@ -fixed: - - Reject polynomial and R1CS lowering on non-straight-line constrain bodies instead of materializing component-scope auxiliaries inside control flow diff --git a/changelogs/unreleased/th__agent_instructions.yaml b/changelogs/unreleased/th__agent_instructions.yaml deleted file mode 100644 index cc71a70e8d..0000000000 --- a/changelogs/unreleased/th__agent_instructions.yaml +++ /dev/null @@ -1,2 +0,0 @@ -added: - - AI coding agent instructions file diff --git a/changelogs/unreleased/th__allow_non_native_in_contract.yaml b/changelogs/unreleased/th__allow_non_native_in_contract.yaml deleted file mode 100644 index 8d1efe03a4..0000000000 --- a/changelogs/unreleased/th__allow_non_native_in_contract.yaml +++ /dev/null @@ -1,2 +0,0 @@ -changed: - - Allow non-native field ops in `verif.contract` diff --git a/changelogs/unreleased/th__cleanup_pass_ods.yaml b/changelogs/unreleased/th__cleanup_pass_ods.yaml deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/changelogs/unreleased/th__fix_313_and_465.yaml b/changelogs/unreleased/th__fix_313_and_465.yaml deleted file mode 100644 index 1076ba2042..0000000000 --- a/changelogs/unreleased/th__fix_313_and_465.yaml +++ /dev/null @@ -1,2 +0,0 @@ -added: - - '`llzk-pod-to-scalar` pass to destructure pod type values into scalar SSA values' diff --git a/changelogs/unreleased/th__instantiate_wildcard.yaml b/changelogs/unreleased/th__instantiate_wildcard.yaml deleted file mode 100644 index 027bdaef66..0000000000 --- a/changelogs/unreleased/th__instantiate_wildcard.yaml +++ /dev/null @@ -1,5 +0,0 @@ -fixed: - - Handle wildcard CallOp template parameters in the flattening pass. - -added: - - '`llzk-specialize-wildcard-arrays` pass to refine array types with wildcard dimensions' diff --git a/changelogs/unreleased/th__pod_op_interfaces.yaml b/changelogs/unreleased/th__pod_op_interfaces.yaml deleted file mode 100644 index d02123a4eb..0000000000 --- a/changelogs/unreleased/th__pod_op_interfaces.yaml +++ /dev/null @@ -1,3 +0,0 @@ -added: - - PodRefOpInterface and PodAccessOpInterface for ops that reference or access pods - - Destructurable*Interface and Promotable*Interface to pod ops/types diff --git a/changelogs/unreleased/th__prevent_new_pod_cse.yaml b/changelogs/unreleased/th__prevent_new_pod_cse.yaml deleted file mode 100644 index d72273c8d6..0000000000 --- a/changelogs/unreleased/th__prevent_new_pod_cse.yaml +++ /dev/null @@ -1,2 +0,0 @@ -fixed: - - Prevent `--cse` from merging distinct `pod.new` allocations. diff --git a/changelogs/unreleased/th__reusable_template.yaml b/changelogs/unreleased/th__reusable_template.yaml deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/changelogs/unreleased/th__update-agent-instructions.yaml b/changelogs/unreleased/th__update-agent-instructions.yaml deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/changelogs/unreleased/th__use_rdv_wrapper.yaml b/changelogs/unreleased/th__use_rdv_wrapper.yaml deleted file mode 100644 index 5e9e24dfdc..0000000000 --- a/changelogs/unreleased/th__use_rdv_wrapper.yaml +++ /dev/null @@ -1,2 +0,0 @@ -fixed: - - array and pod scalarization passes used the original RDV pass instead of the wrapper with bug fix diff --git a/nix/llzk.nix b/nix/llzk.nix index 326b71d415..1831b05702 100644 --- a/nix/llzk.nix +++ b/nix/llzk.nix @@ -12,7 +12,7 @@ }: let - version = "2.1.1"; + version = "2.1.2"; in stdenv.mkDerivation { pname = "llzk-${lib.toLower cmakeBuildType}"; From 69eb5877a1e6084bb33b24b39f71f7479effa600 Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:08:33 -0500 Subject: [PATCH 07/12] Define common pass pipelines (#548) --- .../Transforms/TransformationPasses.h | 2 +- .../r1cs/include/r1cs/Dialect/IR/Attrs.td | 2 +- backends/r1cs/include/r1cs/Dialect/IR/Ops.td | 2 +- .../r1cs/include/r1cs/Dialect/IR/Types.td | 2 +- .../Transforms/TransformationPassPipelines.h | 21 + .../r1cs/Transforms/TransformationPasses.h | 4 +- backends/r1cs/lib/r1cs/Dialect/Dialect.cpp | 2 +- .../r1cs/lib/r1cs/Transforms/CMakeLists.txt | 2 +- .../lib/r1cs/Transforms/R1CSLoweringPass.cpp | 3 + .../TransformationPassPipelines.cpp | 36 +- changelogs/unreleased/th__pass_pipelines.yaml | 8 + include/CMakeLists.txt | 1 + include/llzk-c/Dialect/Struct.h | 1 + .../llzk/Analysis/MemberOverwriteAnalysis.h | 2 +- include/llzk/Analysis/SparseAnalysis.h | 2 +- .../Array/Transforms/TransformationPasses.h | 2 +- .../Include/Transforms/InlineIncludesPass.h | 2 +- .../POD/Transforms/TransformationPasses.h | 2 +- .../Transforms/TransformationPasses.h | 2 +- .../Transforms/TransformationPasses.td | 21 +- include/llzk/Dialect/SMT/IR/SMTAttributes.h | 2 +- .../Dialect/Struct/Transforms/CMakeLists.txt | 10 + .../Struct/Transforms/InlineStructsPass.h} | 2 +- .../Struct/Transforms/TransformationPasses.h | 20 + .../Struct/Transforms/TransformationPasses.td | 41 ++ .../LLZKTransformationPassPipelines.h | 130 +++++ .../Transforms/LLZKTransformationPasses.h | 6 +- .../Transforms/LLZKTransformationPasses.td | 26 - include/llzk/Transforms/Parsers.h | 167 ++++++ include/llzk/Util/SymbolTableLLZK.h | 2 +- .../llzk/Validators/LLZKValidationPasses.h | 2 +- .../LightweightSignalEquivalenceAnalysis.cpp | 2 +- lib/Analysis/SparseAnalysis.cpp | 3 +- lib/CAPI/Dialect/Array.cpp | 2 +- lib/CAPI/Dialect/Include.cpp | 2 +- lib/CAPI/Dialect/POD.cpp | 2 +- lib/CAPI/Dialect/Poly.cpp | 2 +- lib/CAPI/Dialect/Struct.cpp | 6 + lib/CAPI/Transforms.cpp | 2 +- lib/Dialect/Polymorphic/CMakeLists.txt | 12 +- .../Polymorphic/Transforms/FlatteningPass.cpp | 13 +- lib/Dialect/SMT/SMTAttributes.cpp | 2 +- lib/Dialect/SMT/SMTDialect.cpp | 2 +- lib/Dialect/SMT/SMTOps.cpp | 2 +- lib/Dialect/SMT/SMTTypes.cpp | 2 +- lib/Dialect/Struct/CMakeLists.txt | 2 +- .../Struct/Transforms/InlineStructsPass.cpp} | 28 +- lib/Transforms/CMakeLists.txt | 2 +- .../LLZKComputeConstrainToProductPass.cpp | 2 +- lib/Transforms/LLZKFuseProductLoopsPass.cpp | 2 +- lib/Transforms/LLZKLoweringUtils.cpp | 2 +- .../LLZKTransformationPassPipelines.cpp | 154 +++++- .../LLZKUnusedDeclarationEliminationPass.cpp | 46 +- nix/llzk.nix | 2 +- .../InlineStructs/circom_subcmps5b.llzk | 22 +- .../full_struct_inlining_cleanup_bug.llzk | 88 ++++ .../inline_structs_max_complexity.llzk | 496 +++++++++--------- .../InlineStructs/inline_structs_pass_2.llzk | 98 ++-- .../poly_lowering_fail_low_deg.llzk | 2 +- .../poly_lowering_fail_reserved_name.llzk | 2 +- .../PolyLowering/poly_lowering_pass_deg2.llzk | 281 +++++----- .../PolyLowering/poly_lowering_pass_deg3.llzk | 2 +- .../R1CSLowering/r1cs_lowering_pass.llzk | 4 +- .../r1cs_lowering_quadratic_linear_sign.llzk | 2 +- .../r1cs_lowering_typed_aux_member.llzk | 2 +- .../unused_decl_after_redundant_elim.llzk | 71 +++ .../unused_declaration_pass.llzk | 44 +- tools/llzk-opt/llzk-opt.cpp | 10 +- 68 files changed, 1310 insertions(+), 635 deletions(-) create mode 100644 backends/r1cs/include/r1cs/Transforms/TransformationPassPipelines.h create mode 100644 changelogs/unreleased/th__pass_pipelines.yaml create mode 100644 include/llzk/Dialect/Struct/Transforms/CMakeLists.txt rename include/llzk/{Transforms/LLZKInlineStructsPass.h => Dialect/Struct/Transforms/InlineStructsPass.h} (94%) create mode 100644 include/llzk/Dialect/Struct/Transforms/TransformationPasses.h create mode 100644 include/llzk/Dialect/Struct/Transforms/TransformationPasses.td create mode 100644 include/llzk/Transforms/LLZKTransformationPassPipelines.h rename lib/{Transforms/LLZKInlineStructsPass.cpp => Dialect/Struct/Transforms/InlineStructsPass.cpp} (97%) create mode 100644 test/Transforms/InlineStructs/full_struct_inlining_cleanup_bug.llzk create mode 100644 test/Transforms/RedundantAndUnusedElim/unused_decl_after_redundant_elim.llzk diff --git a/backends/pcl-conv/include/pcl-conv/Transforms/TransformationPasses.h b/backends/pcl-conv/include/pcl-conv/Transforms/TransformationPasses.h index f9cd3e60c2..64b6d2d7c2 100644 --- a/backends/pcl-conv/include/pcl-conv/Transforms/TransformationPasses.h +++ b/backends/pcl-conv/include/pcl-conv/Transforms/TransformationPasses.h @@ -19,4 +19,4 @@ namespace pcl::conversion { #define GEN_PASS_REGISTRATION #include "pcl-conv/Transforms/TransformationPasses.h.inc" -}; // namespace pcl::conversion +} // namespace pcl::conversion diff --git a/backends/r1cs/include/r1cs/Dialect/IR/Attrs.td b/backends/r1cs/include/r1cs/Dialect/IR/Attrs.td index 20f1bb10a1..d4ca8e4b3a 100644 --- a/backends/r1cs/include/r1cs/Dialect/IR/Attrs.td +++ b/backends/r1cs/include/r1cs/Dialect/IR/Attrs.td @@ -1,4 +1,4 @@ -//===-- Attrs.td -----------------------------------------*--- tablegen -*-===// +//===-- Attrs.td -------------------------------------------*- tablegen -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/backends/r1cs/include/r1cs/Dialect/IR/Ops.td b/backends/r1cs/include/r1cs/Dialect/IR/Ops.td index fa30150573..ea117761b3 100644 --- a/backends/r1cs/include/r1cs/Dialect/IR/Ops.td +++ b/backends/r1cs/include/r1cs/Dialect/IR/Ops.td @@ -1,4 +1,4 @@ -//===-- Ops.td -------------------------------------------*--- tablegen -*-===// +//===-- Ops.td ---------------------------------------------*- tablegen -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/backends/r1cs/include/r1cs/Dialect/IR/Types.td b/backends/r1cs/include/r1cs/Dialect/IR/Types.td index 239f8a0384..83812f7d7a 100644 --- a/backends/r1cs/include/r1cs/Dialect/IR/Types.td +++ b/backends/r1cs/include/r1cs/Dialect/IR/Types.td @@ -1,4 +1,4 @@ -//===-- Types.td -----------------------------------------*--- tablegen -*-===// +//===-- Types.td -------------------------------------------*- tablegen -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/backends/r1cs/include/r1cs/Transforms/TransformationPassPipelines.h b/backends/r1cs/include/r1cs/Transforms/TransformationPassPipelines.h new file mode 100644 index 0000000000..4ff5a124b8 --- /dev/null +++ b/backends/r1cs/include/r1cs/Transforms/TransformationPassPipelines.h @@ -0,0 +1,21 @@ +//===-- TransformationPassPipelines.h ---------------------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace r1cs { + +void buildFullR1CSLoweringPipeline(mlir::OpPassManager &); + +void registerTransformationPassPipelines(); + +} // namespace r1cs diff --git a/backends/r1cs/include/r1cs/Transforms/TransformationPasses.h b/backends/r1cs/include/r1cs/Transforms/TransformationPasses.h index 7ac267a674..22969a902f 100644 --- a/backends/r1cs/include/r1cs/Transforms/TransformationPasses.h +++ b/backends/r1cs/include/r1cs/Transforms/TransformationPasses.h @@ -15,10 +15,8 @@ namespace r1cs { -void registerTransformationPassPipelines(); - #define GEN_PASS_DECL #define GEN_PASS_REGISTRATION #include "r1cs/Transforms/TransformationPasses.h.inc" -}; // namespace r1cs +} // namespace r1cs diff --git a/backends/r1cs/lib/r1cs/Dialect/Dialect.cpp b/backends/r1cs/lib/r1cs/Dialect/Dialect.cpp index e48c75d7f9..f360605dca 100644 --- a/backends/r1cs/lib/r1cs/Dialect/Dialect.cpp +++ b/backends/r1cs/lib/r1cs/Dialect/Dialect.cpp @@ -1,4 +1,4 @@ -//===-- Dialect.cpp - R1CS dialect implementation -----------*- C++ -*-----===// +//===-- Dialect.cpp - R1CS dialect implementation ---------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/backends/r1cs/lib/r1cs/Transforms/CMakeLists.txt b/backends/r1cs/lib/r1cs/Transforms/CMakeLists.txt index 02dc7fa4e9..0f6679745a 100644 --- a/backends/r1cs/lib/r1cs/Transforms/CMakeLists.txt +++ b/backends/r1cs/lib/r1cs/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_library(R1CS::Transforms ALIAS R1CSTransforms) file(GLOB R1CSTransforms_SOURCES "*.cpp") - target_sources(R1CSTransforms PRIVATE ${R1CSTransforms_SOURCES}) target_link_libraries( @@ -15,6 +14,7 @@ target_link_libraries( MLIRIR MLIRPass MLIRParser MLIRTransformUtils MLIRSCFTransforms LLVMHeaders MLIRHeaders PRIVATE + LLZKTransforms LLZKDialect R1CSDialect LLZKAnalysis diff --git a/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp b/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp index 880b02f398..3fbd597dce 100644 --- a/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp +++ b/backends/r1cs/lib/r1cs/Transforms/R1CSLoweringPass.cpp @@ -704,6 +704,9 @@ class PassImpl : public r1cs::impl::R1CSLoweringPassBase { buildAndEmitR1CS(moduleOp, structDef, constrainFunc, degreeMemo); structDef.erase(); }); + + // Remove `llzk.main` attribute because all structs were replaced with `r1cs.circuit` ops. + moduleOp->removeAttr(MAIN_ATTR_NAME); } }; diff --git a/backends/r1cs/lib/r1cs/Transforms/TransformationPassPipelines.cpp b/backends/r1cs/lib/r1cs/Transforms/TransformationPassPipelines.cpp index 7d6277fefe..5b4d900c32 100644 --- a/backends/r1cs/lib/r1cs/Transforms/TransformationPassPipelines.cpp +++ b/backends/r1cs/lib/r1cs/Transforms/TransformationPassPipelines.cpp @@ -12,9 +12,11 @@ /// //===----------------------------------------------------------------------===// +#include "r1cs/Transforms/TransformationPassPipelines.h" + #include "r1cs/Transforms/TransformationPasses.h" -#include "llzk/Transforms/LLZKTransformationPasses.h" +#include "llzk/Transforms/LLZKTransformationPassPipelines.h" #include #include @@ -24,22 +26,28 @@ using namespace mlir; namespace r1cs { -void registerTransformationPassPipelines() { - PassPipelineRegistration<>( - "llzk-full-r1cs-lowering", "Lower already-flattened polynomial constraints to r1cs", - [](OpPassManager &pm) { - // 1. Degree lowering - pm.addPass(llzk::createPolyLoweringPass(llzk::PolyLoweringPassOptions {.maxDegree = 2})); +void buildFullR1CSLoweringPipeline(OpPassManager &pm) { + // 1. Polynomial degree lowering and cleanup + llzk::FullPolyLoweringConfig config; + config.polyLowering = llzk::PolyLoweringPassOptions {.maxDegree = 2}; + llzk::buildFullPolyLoweringPipeline(pm, config); - // 2. Cleanup - llzk::addRemoveUnnecessaryOpsAndDefsPipeline(pm); + // 2. Convert to R1CS + pm.addPass(r1cs::createR1CSLoweringPass()); - // 3. Convert to R1CS - pm.addPass(r1cs::createR1CSLoweringPass()); + // 3. Run CSE to eliminate to_linear ops + pm.addPass(mlir::createCSEPass()); - // 4. Run CSE to eliminate to_linear ops - pm.addPass(mlir::createCSEPass()); - } + // Other passes that may be helpful to add in the future: + // - llzk::createRemoveDeadValuesWorkaroundPass() + // - mlir::createCanonicalizerPass() + // (was run via poly-lowering -> struct-inlining but again may be useful) +} + +void registerTransformationPassPipelines() { + PassPipelineRegistration<>( + "llzk-full-r1cs-lowering", "Lower polynomial constraints to r1cs", + buildFullR1CSLoweringPipeline ); } diff --git a/changelogs/unreleased/th__pass_pipelines.yaml b/changelogs/unreleased/th__pass_pipelines.yaml new file mode 100644 index 0000000000..4e83ee9afc --- /dev/null +++ b/changelogs/unreleased/th__pass_pipelines.yaml @@ -0,0 +1,8 @@ +added: + - "`llzk-full-struct-inlining` pass pipeline definition" + +changed: + - "[header incompatibility] move `llzk-inline-structs` pass definition to `llzk/Dialect/Struct/Transforms/TransformationPasses.h`" + - "[CLI incompatibility] `llzk-full-poly-lowering` option `max-degree` is now `lowering={max-degree=3}`" + - "run `llzk-full-struct-inlining` at start of `llzk-full-poly-lowering` pass pipeline" + - "run `llzk-full-poly-lowering` with `max-degree=2` at start of `llzk-full-r1cs-lowering` pass pipeline" diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 1416c1760f..f93db820ec 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -42,6 +42,7 @@ add_subdirectory(llzk/Dialect/Shared) add_subdirectory(llzk/Dialect/SMT/IR) add_subdirectory(llzk/Dialect/String/IR) add_subdirectory(llzk/Dialect/Struct/IR) +add_subdirectory(llzk/Dialect/Struct/Transforms) add_subdirectory(llzk/Dialect/Verif/IR) add_subdirectory(llzk/Transforms) diff --git a/include/llzk-c/Dialect/Struct.h b/include/llzk-c/Dialect/Struct.h index c824bdaa4d..865516685e 100644 --- a/include/llzk-c/Dialect/Struct.h +++ b/include/llzk-c/Dialect/Struct.h @@ -29,6 +29,7 @@ // Include the generated CAPI #include "llzk/Dialect/Struct/IR/Ops.capi.h.inc" #include "llzk/Dialect/Struct/IR/Types.capi.h.inc" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.capi.h.inc" #ifdef __cplusplus extern "C" { diff --git a/include/llzk/Analysis/MemberOverwriteAnalysis.h b/include/llzk/Analysis/MemberOverwriteAnalysis.h index 7592d9050a..ae191d2751 100644 --- a/include/llzk/Analysis/MemberOverwriteAnalysis.h +++ b/include/llzk/Analysis/MemberOverwriteAnalysis.h @@ -149,4 +149,4 @@ class MemberOverwriteAnalysis void setToEntryState(MemberOverwriteLattice *lattice) override { lattice->entry(); } }; -}; // namespace llzk +} // namespace llzk diff --git a/include/llzk/Analysis/SparseAnalysis.h b/include/llzk/Analysis/SparseAnalysis.h index 4a0973971c..9fe027f77d 100644 --- a/include/llzk/Analysis/SparseAnalysis.h +++ b/include/llzk/Analysis/SparseAnalysis.h @@ -1,4 +1,4 @@ -//===- SparseAnalysis.h - Sparse data-flow analysis -----------------------===// +//===- SparseAnalysis.h - Sparse data-flow analysis -------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/include/llzk/Dialect/Array/Transforms/TransformationPasses.h b/include/llzk/Dialect/Array/Transforms/TransformationPasses.h index 55e2ed50e4..c1b7b9df76 100644 --- a/include/llzk/Dialect/Array/Transforms/TransformationPasses.h +++ b/include/llzk/Dialect/Array/Transforms/TransformationPasses.h @@ -17,4 +17,4 @@ namespace llzk::array { #define GEN_PASS_REGISTRATION #include "llzk/Dialect/Array/Transforms/TransformationPasses.h.inc" -}; // namespace llzk::array +} // namespace llzk::array diff --git a/include/llzk/Dialect/Include/Transforms/InlineIncludesPass.h b/include/llzk/Dialect/Include/Transforms/InlineIncludesPass.h index ab479eb2eb..8bc0d2f418 100644 --- a/include/llzk/Dialect/Include/Transforms/InlineIncludesPass.h +++ b/include/llzk/Dialect/Include/Transforms/InlineIncludesPass.h @@ -18,4 +18,4 @@ namespace llzk::include { #define GEN_PASS_REGISTRATION #include "llzk/Dialect/Include/Transforms/InlineIncludesPass.h.inc" -}; // namespace llzk::include +} // namespace llzk::include diff --git a/include/llzk/Dialect/POD/Transforms/TransformationPasses.h b/include/llzk/Dialect/POD/Transforms/TransformationPasses.h index f4ef7a09a2..4c0fecd290 100644 --- a/include/llzk/Dialect/POD/Transforms/TransformationPasses.h +++ b/include/llzk/Dialect/POD/Transforms/TransformationPasses.h @@ -17,4 +17,4 @@ namespace llzk::pod { #define GEN_PASS_REGISTRATION #include "llzk/Dialect/POD/Transforms/TransformationPasses.h.inc" -}; // namespace llzk::pod +} // namespace llzk::pod diff --git a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h index 3449428337..fcbb11e34d 100644 --- a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h +++ b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h @@ -18,4 +18,4 @@ namespace llzk::polymorphic { #define GEN_PASS_REGISTRATION #include "llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h.inc" -}; // namespace llzk::polymorphic +} // namespace llzk::polymorphic diff --git a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td index 51c9f20dcc..dd4292cfb0 100644 --- a/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td +++ b/include/llzk/Dialect/Polymorphic/Transforms/TransformationPasses.td @@ -23,18 +23,21 @@ def FlatteningCleanupModeDescription { def FlatteningCleanupMode : I32EnumAttr<"FlatteningCleanupMode", FlatteningCleanupModeDescription.r, [ + // Unspecified: Leave the cleanup choice to the caller + // (defaults to `preimage` if not specified). + I32EnumAttrCase<"Unspecified", 0, "unspecified">, // Disabled: No definitions are deleted. - I32EnumAttrCase<"Disabled", 0, "disabled">, + I32EnumAttrCase<"Disabled", 1, "disabled">, // Preimage: Only definitions that were replaced with // concrete instantiations are deleted. - I32EnumAttrCase<"Preimage", 1, "preimage">, + I32EnumAttrCase<"Preimage", 2, "preimage">, // ConcreteAsRoot: All definitions that cannot be reached // by a use-def chain from some concrete definition are // deleted. - I32EnumAttrCase<"ConcreteAsRoot", 2, "concrete-as-root">, + I32EnumAttrCase<"ConcreteAsRoot", 3, "concrete-as-root">, // MainAsRoot: All definitions that cannot be reached by a // use-def chain from the main struct are deleted. - I32EnumAttrCase<"MainAsRoot", 3, "main-as-root">, + I32EnumAttrCase<"MainAsRoot", 4, "main-as-root">, ]> { let cppNamespace = "::llzk::polymorphic"; let genSpecializedAttr = 0; @@ -60,6 +63,8 @@ def FlatteningPass : LLZKPass<"llzk-flatten"> { to `compute()` functions and unroll loops - Unroll loops }]; + // Implementation note: These options should be kept in sync with + // `StructInliningFlatteningOptions` in `LLZKTransformationPassPipelines.h`. let options = [Option<"iterationLimit", "max-iter", "unsigned", /* default */ "1000", @@ -72,10 +77,14 @@ def FlatteningPass : LLZKPass<"llzk-flatten"> { Option<"cleanupMode", "cleanup", "::llzk::polymorphic::FlatteningCleanupMode", /* default */ - "::llzk::polymorphic::FlatteningCleanupMode::Preimage", + "::llzk::polymorphic::FlatteningCleanupMode::Unspecified", FlatteningCleanupModeDescription.r, [{::llvm::cl::values( + clEnumValN(::llzk::polymorphic::FlatteningCleanupMode::Unspecified, + stringifyFlatteningCleanupMode(::llzk::polymorphic::FlatteningCleanupMode::Unspecified), + "Use the cleanup mode specified by the calling pipeline (defaults to `preimage` if not specified)."), clEnumValN(::llzk::polymorphic::FlatteningCleanupMode::Disabled, - stringifyFlatteningCleanupMode(::llzk::polymorphic::FlatteningCleanupMode::Disabled), "No definitions are deleted."), + stringifyFlatteningCleanupMode(::llzk::polymorphic::FlatteningCleanupMode::Disabled), + "No definitions are deleted."), clEnumValN(::llzk::polymorphic::FlatteningCleanupMode::Preimage, stringifyFlatteningCleanupMode(::llzk::polymorphic::FlatteningCleanupMode::Preimage), "Only definitions that were replaced with concrete instantiations are deleted."), diff --git a/include/llzk/Dialect/SMT/IR/SMTAttributes.h b/include/llzk/Dialect/SMT/IR/SMTAttributes.h index 54dbb6cf83..d1ed9b4fac 100644 --- a/include/llzk/Dialect/SMT/IR/SMTAttributes.h +++ b/include/llzk/Dialect/SMT/IR/SMTAttributes.h @@ -1,4 +1,4 @@ -//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===// +//===- SMTAttributes.h - Declare SMT dialect attributes ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/include/llzk/Dialect/Struct/Transforms/CMakeLists.txt b/include/llzk/Dialect/Struct/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..6a4524ec14 --- /dev/null +++ b/include/llzk/Dialect/Struct/Transforms/CMakeLists.txt @@ -0,0 +1,10 @@ +include_directories(${MLIR_INCLUDE_DIRS} ${LLZK_INCLUDE_DIR}) + +set(LLVM_TARGET_DEFINITIONS "TransformationPasses.td") +mlir_tablegen(TransformationPasses.h.inc -gen-pass-decls -name=Transformation) +mlir_tablegen(TransformationPasses.capi.h.inc -gen-pass-capi-header --prefix LLZKStructTransformation) +mlir_tablegen(TransformationPasses.capi.cpp.inc -gen-pass-capi-impl --prefix LLZKStructTransformation) +llzk_add_mlir_doc(StructTransformationPassesDocGen passes/struct/TransformationPasses.md -gen-pass-doc) + +add_public_tablegen_target(StructTransformationIncGen) +add_dependencies(LLZKDialectHeaders StructTransformationIncGen) diff --git a/include/llzk/Transforms/LLZKInlineStructsPass.h b/include/llzk/Dialect/Struct/Transforms/InlineStructsPass.h similarity index 94% rename from include/llzk/Transforms/LLZKInlineStructsPass.h rename to include/llzk/Dialect/Struct/Transforms/InlineStructsPass.h index eb7815585b..9dd7cd363b 100644 --- a/include/llzk/Transforms/LLZKInlineStructsPass.h +++ b/include/llzk/Dialect/Struct/Transforms/InlineStructsPass.h @@ -1,4 +1,4 @@ -//===-- LLZKInlineStructsPass.h ---------------------------------*- C++ -*-===// +//===-- InlineStructsPass.h -------------------------------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/include/llzk/Dialect/Struct/Transforms/TransformationPasses.h b/include/llzk/Dialect/Struct/Transforms/TransformationPasses.h new file mode 100644 index 0000000000..f4e0ff65d3 --- /dev/null +++ b/include/llzk/Dialect/Struct/Transforms/TransformationPasses.h @@ -0,0 +1,20 @@ +//===-- TransformationPasses.h ---------------------------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "llzk/Pass/PassBase.h" + +namespace llzk::component { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h.inc" + +} // namespace llzk::component diff --git a/include/llzk/Dialect/Struct/Transforms/TransformationPasses.td b/include/llzk/Dialect/Struct/Transforms/TransformationPasses.td new file mode 100644 index 0000000000..1973a63b8c --- /dev/null +++ b/include/llzk/Dialect/Struct/Transforms/TransformationPasses.td @@ -0,0 +1,41 @@ +//===-- TransformationPasses.td ----------------------------*- tablegen -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#ifndef LLZK_DIALECT_STRUCT_TRANSFORMS_TRANSFORMATION_PASSES_TD +#define LLZK_DIALECT_STRUCT_TRANSFORMS_TRANSFORMATION_PASSES_TD + +include "llzk/Pass/PassBase.td" + +def InlineStructsPass : LLZKPass<"llzk-inline-structs"> { + let summary = "Inlines nested structs (i.e., subcomponents)."; + let description = [{ + This pass inlines nested structs (i.e., subcomponents) at struct-type members and at calls to the + subcomponent compute/constrain functions. Inlining decisions are guided by the call graph of + "constrain" functions. + + The `max-merge-complexity` parameter can be used to limit the complexity of the resulting structs such + that a potential inlining will not take place if doing so would push the sum of constraint and + multiplications in the combined struct over the limit. The default value `0` indicates no limits + which means all structs will be inlined into the Main struct. + + This pass should be run after `llzk-flatten` to ensure structs do not have template parameters + because structs with template parameters cannot (currently) be inlined. Inlining is also not + (currently) supported for subcomponent structs stored in an array-type member. + + This pass also assumes that all subcomponents that are created by calling a struct "@compute" + function are ultimately written to exactly one member within the current struct. + }]; + let options = [Option<"maxComplexity", "max-merge-complexity", "uint64_t", + /* default: no limit */ "0", + "Maximum allowed constraint+multiplications in merged " + "@constrain functions">, + ]; +} + +#endif // LLZK_DIALECT_STRUCT_TRANSFORMS_TRANSFORMATION_PASSES_TD diff --git a/include/llzk/Transforms/LLZKTransformationPassPipelines.h b/include/llzk/Transforms/LLZKTransformationPassPipelines.h new file mode 100644 index 0000000000..f90d3cf98f --- /dev/null +++ b/include/llzk/Transforms/LLZKTransformationPassPipelines.h @@ -0,0 +1,130 @@ +//===-- LLZKTransformationPassPipelines.h -----------------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h" +#include "llzk/Transforms/LLZKTransformationPasses.h" +#include "llzk/Transforms/Parsers.h" + +#include +#include + +namespace llzk { + +/// Typed nested options for the flattening pass when used inside the full +/// struct-inlining and full poly-lowering pipelines. +struct StructInliningFlatteningOptions + : public mlir::PassPipelineOptions { + // Implementation note: these options should be kept in sync with the `FlatteningPass` ODS. + + Option iterationLimit { + *this, "max-iter", llvm::cl::desc("maximum number of flattening iterations before giving up"), + llvm::cl::init(1000) + }; + + Option cleanupMode { + *this, "cleanup", + llvm::cl::desc( + "cleanup mode for flattening in this pipeline. When left as `unspecified`, these " + "pipelines use `main-as-root`. Overriding this is not recommended because the later " + "`llzk-inline-structs` pass may crash if parameterized templates survive flattening." + ), + llvm::cl::init(polymorphic::FlatteningCleanupMode::Unspecified) + }; + + polymorphic::FlatteningPassOptions createPassOptions() const { + return polymorphic::FlatteningPassOptions { + .iterationLimit = iterationLimit, .cleanupMode = cleanupMode + }; + } +}; + +/// Pure C++ configuration for the full struct inlining pipeline. +struct FullStructInliningConfig { + polymorphic::FlatteningPassOptions flattening; + bool arrayToScalar = true; + bool podToScalar = true; + component::InlineStructsPassOptions inlining; +}; + +/// CLI Option configuration for the full struct inlining pipeline. +struct FullStructInliningOptions : public mlir::PassPipelineOptions { + + using FlatteningOptions = NestedPipelineOptions; + + using InliningOptions = NestedPassOptions< + static_cast (*)()>(&llzk::component::createInlineStructsPass)>; + + Option flattening { + *this, "flattening", + llvm::cl::desc( + "options for the flattening pass used in this pipeline; this pipeline defaults " + "flattening pass cleanup to `main-as-root`" + ), + llvm::cl::init(FlatteningOptions {}) + }; + Option arrayToScalar { + *this, "array-to-scalar", + llvm::cl::desc("whether to run the array-to-scalar pass in this pipeline"), + llvm::cl::init(true) + }; + Option podToScalar { + *this, "pod-to-scalar", + llvm::cl::desc("whether to run the pod-to-scalar pass in this pipeline"), llvm::cl::init(true) + }; + Option inlining { + *this, "inlining", llvm::cl::desc("options for the inlining pass used in this pipeline"), + llvm::cl::init(InliningOptions {}) + }; +}; + +/// Pure C++ configuration for the full polynomial lowering pipeline. +struct FullPolyLoweringConfig { + FullStructInliningConfig structInlining; + PolyLoweringPassOptions polyLowering; +}; + +/// CLI Option configuration for the full polynomial lowering pipeline. +struct FullPolyLoweringOptions : public mlir::PassPipelineOptions { + + using StructInliningOptions = NestedPipelineOptions; + + using PolyLoweringOptions = NestedPassOptions< + static_cast (*)()>(&llzk::createPolyLoweringPass)>; + + Option structInlining { + *this, "flatten-inline", + llvm::cl::desc( + "options for the struct flattening and inlining pipeline used before polynomial " + "lowering; this pipeline defaults flattening cleanup to `main-as-root`" + ), + llvm::cl::init(StructInliningOptions {}) + }; + Option polyLowering { + *this, "lowering", + llvm::cl::desc("options for the polynomial lowering pass used in this pipeline"), + llvm::cl::init(PolyLoweringOptions {}) + }; +}; + +void buildRemoveUnnecessaryOpsPipeline(mlir::OpPassManager &); + +void buildRemoveUnnecessaryOpsAndDefsPipeline(mlir::OpPassManager &); + +void buildFullPolyLoweringPipeline(mlir::OpPassManager &, const FullPolyLoweringConfig &); + +void buildProductProgramPipeline(mlir::OpPassManager &); + +void buildFullStructInliningPipeline(mlir::OpPassManager &, const FullStructInliningConfig &); + +void registerTransformationPassPipelines(); + +} // namespace llzk diff --git a/include/llzk/Transforms/LLZKTransformationPasses.h b/include/llzk/Transforms/LLZKTransformationPasses.h index 7076b06fbc..eb6771f81b 100644 --- a/include/llzk/Transforms/LLZKTransformationPasses.h +++ b/include/llzk/Transforms/LLZKTransformationPasses.h @@ -16,14 +16,10 @@ namespace llzk { -void addRemoveUnnecessaryOpsAndDefsPipeline(mlir::OpPassManager &pm); - -void registerTransformationPassPipelines(); - void registerInliningExtensions(mlir::DialectRegistry ®istry); #define GEN_PASS_DECL #define GEN_PASS_REGISTRATION #include "llzk/Transforms/LLZKTransformationPasses.h.inc" -}; // namespace llzk +} // namespace llzk diff --git a/include/llzk/Transforms/LLZKTransformationPasses.td b/include/llzk/Transforms/LLZKTransformationPasses.td index 8172f80d4f..88c95c66c3 100644 --- a/include/llzk/Transforms/LLZKTransformationPasses.td +++ b/include/llzk/Transforms/LLZKTransformationPasses.td @@ -90,32 +90,6 @@ def PolyLoweringPass : LLZKPass<"llzk-poly-lowering-pass"> { ]; } -def InlineStructsPass : LLZKPass<"llzk-inline-structs"> { - let summary = "Inlines nested structs (i.e., subcomponents)."; - let description = [{ - This pass inlines nested structs (i.e., subcomponents) at struct-type members and at calls to the - subcomponent compute/constrain functions. Inlining decisions are guided by the call graph of - "constrain" functions. - - The `max-merge-complexity` parameter can be used to limit the complexity of the resulting structs such - that a potential inlining will not take place if doing so would push the sum of constraint and - multiplications in the combined struct over the limit. The default value `0` indicates no limits - which means all structs will be inlined into the Main struct. - - This pass should be run after `llzk-flatten` to ensure structs do not have template parameters - because structs with template parameters cannot (currently) be inlined. Inlining is also not - (currently) supported for subcomponent structs stored in an array-type member. - - This pass also assumes that all subcomponents that are created by calling a struct "@compute" - function are ultimately written to exactly one member within the current struct. - }]; - let options = [Option<"maxComplexity", "max-merge-complexity", "uint64_t", - /* default */ "0", - "Maximum allowed constraint+multiplications in merged " - "@constrain functions">, - ]; -} - def ComputeConstrainToProductPass : LLZKPass<"llzk-compute-constrain-to-product"> { let summary = "Replace separate @compute and @constrain functions in a " diff --git a/include/llzk/Transforms/Parsers.h b/include/llzk/Transforms/Parsers.h index 806be8b4a1..a64c7d1e49 100644 --- a/include/llzk/Transforms/Parsers.h +++ b/include/llzk/Transforms/Parsers.h @@ -5,6 +5,7 @@ // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. // Copyright 2025 Veridise Inc. +// Copyright 2026 Project LLZK // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// @@ -13,14 +14,180 @@ #include "llzk/Util/Compare.h" +#include + #include #include #include +#include +#include + +#include + +namespace llzk { + +namespace detail { + +/// Shared storage and helpers for nested textual pass and pipeline options. +struct NestedTextualOptions { + /// The validated textual form without the outer `{...}` delimiters. + std::string str; + + /// Recreate a nested value from the stored option string and re-validate it. + /// + /// The CLI parser validates `str` before storing it, but callers still need a + /// fresh initialized pass or pipeline options object when materializing the + /// nested configuration later. + template + std::unique_ptr createValidatedValue( + CreateFnT &&createValue, llvm::StringRef kind, InitializeFnT &&initializeValue + ) const { + auto value = createValue(); + if (str.empty()) { + return value; + } + + std::string error; + if (failed(initializeValue(*value, str, error))) { + llvm::report_fatal_error( + llvm::Twine("failed to initialize previously-validated nested ") + kind + + " options: " + error + ); + } + return value; + } +}; + +} // namespace detail + +/// Stores textual options for a constituent pass after validating them against +/// that pass' native MLIR option parser. +template struct NestedPassOptions : detail::NestedTextualOptions { + /// Build a fresh pass instance with the validated options applied. + std::unique_ptr createPass() const { + return this->createValidatedValue(CreatePass, "pass", initializePass); + } + + static mlir::LogicalResult + initializePass(mlir::Pass &pass, llvm::StringRef options, std::string &error) { + return pass.initializeOptions(options, [&error](const llvm::Twine &message) { + error = message.str(); + return mlir::failure(); + }); + } +}; + +/// Stores textual options for a constituent pipeline after validating them +/// against that pipeline's native MLIR option parser. +template struct NestedPipelineOptions : detail::NestedTextualOptions { + /// Build a fresh options object with the validated options applied. + std::unique_ptr createOptions() const { + return this->createValidatedValue( + std::make_unique, "pipeline", initializeOptions + ); + } + + static mlir::LogicalResult + initializeOptions(PipelineOptionsT &options, llvm::StringRef value, std::string &error) { + llvm::raw_string_ostream errorStream(error); + return options.parseFromString(value, errorStream); + } +}; + +} // namespace llzk // Custom command line parsers namespace llvm { namespace cl { +template class NestedOptionsParserBase : public basic_parser { +public: + NestedOptionsParserBase(Option &O) : basic_parser(O) {} + +protected: + /// Parse a nested pass or pipeline option payload, optionally stripping a + /// surrounding `{...}` wrapper. + bool parseNestedOptions(Option &O, StringRef Arg, StringRef kind, StringRef &options) const { + options = Arg; + if (options.consume_front("{") && !options.consume_back("}")) { + return O.error(llvm::Twine("expected nested ") + kind + " options to end with '}'"); + } + return false; + } + +public: + static void print(llvm::raw_ostream &OS, const OptionsT &Val) { OS << '{' << Val.str << '}'; } + + void printOptionDiff( + const Option &O, const OptionsT &V, const OptionValue &Default, size_t GlobalWidth + ) const { + this->printOptionName(O, GlobalWidth); + print(llvm::outs(), V); + llvm::outs() << " (default: "; + if (Default.hasValue()) { + print(llvm::outs(), Default.getValue()); + } else { + llvm::outs() << ""; + } + llvm::outs() << ")\n"; + } +}; + +/// Parser for textual options that are validated by a constituent MLIR pass. +template +class parser> + : public NestedOptionsParserBase> { +public: + using OptionsT = llzk::NestedPassOptions; + using Base = NestedOptionsParserBase; + + parser(Option &O) : Base(O) {} + + bool parse(Option &O, StringRef, StringRef Arg, OptionsT &Val) { + StringRef options; + if (this->parseNestedOptions(O, Arg, "pass", options)) { + return true; + } + + auto pass = CreatePass(); + std::string error; + if (failed(OptionsT::initializePass(*pass, options, error))) { + return O.error(error); + } + + Val.str = options.str(); + return false; + } +}; + +/// Parser for textual options that are validated by a constituent MLIR +/// pipeline. +template +class parser> + : public NestedOptionsParserBase> { +public: + using OptionsT = llzk::NestedPipelineOptions; + using Base = NestedOptionsParserBase; + + parser(Option &O) : Base(O) {} + + bool parse(Option &O, StringRef, StringRef Arg, OptionsT &Val) { + StringRef options; + if (this->parseNestedOptions(O, Arg, "pipeline", options)) { + return true; + } + + PipelineOptionsT pipelineOptions; + std::string error; + if (failed(OptionsT::initializeOptions(pipelineOptions, options, error))) { + return O.error(error); + } + + Val.str = options.str(); + return false; + } +}; + // Parser for APInt template <> class parser : public basic_parser { public: diff --git a/include/llzk/Util/SymbolTableLLZK.h b/include/llzk/Util/SymbolTableLLZK.h index 4b4edc5203..a5add10e4c 100644 --- a/include/llzk/Util/SymbolTableLLZK.h +++ b/include/llzk/Util/SymbolTableLLZK.h @@ -1,4 +1,4 @@ -//===-- SymbolTableLLZK.h ------------------------------------------*- C++ -*-===// +//===-- SymbolTableLLZK.h ---------------------------------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/include/llzk/Validators/LLZKValidationPasses.h b/include/llzk/Validators/LLZKValidationPasses.h index ad84cf7a8f..5e31cff78c 100644 --- a/include/llzk/Validators/LLZKValidationPasses.h +++ b/include/llzk/Validators/LLZKValidationPasses.h @@ -17,4 +17,4 @@ namespace llzk { #define GEN_PASS_REGISTRATION #include "llzk/Validators/LLZKValidationPasses.h.inc" -}; // namespace llzk +} // namespace llzk diff --git a/lib/Analysis/LightweightSignalEquivalenceAnalysis.cpp b/lib/Analysis/LightweightSignalEquivalenceAnalysis.cpp index c9ecc2200f..b795fcd728 100644 --- a/lib/Analysis/LightweightSignalEquivalenceAnalysis.cpp +++ b/lib/Analysis/LightweightSignalEquivalenceAnalysis.cpp @@ -1,4 +1,4 @@ -//===- LightweightSignalEquivalenceAnalysis.cpp ---------------------------===// +//===- LightweightSignalEquivalenceAnalysis.cpp -----------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/lib/Analysis/SparseAnalysis.cpp b/lib/Analysis/SparseAnalysis.cpp index 031a7bb81e..3c289ca1f9 100644 --- a/lib/Analysis/SparseAnalysis.cpp +++ b/lib/Analysis/SparseAnalysis.cpp @@ -1,4 +1,4 @@ -//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===// +//===- SparseAnalysis.cpp - Sparse data-flow analysis -----------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. @@ -10,7 +10,6 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// //===----------------------------------------------------------------------===// #include "llzk/Analysis/SparseAnalysis.h" diff --git a/lib/CAPI/Dialect/Array.cpp b/lib/CAPI/Dialect/Array.cpp index 9b67584ceb..4204ae20c0 100644 --- a/lib/CAPI/Dialect/Array.cpp +++ b/lib/CAPI/Dialect/Array.cpp @@ -30,7 +30,7 @@ using namespace mlir; using namespace llzk; using namespace llzk::array; -static void registerLLZKArrayTransformationPasses() { registerTransformationPasses(); } +static inline void registerLLZKArrayTransformationPasses() { registerTransformationPasses(); } // Include the generated CAPI #include "llzk/Dialect/Array/IR/Ops.capi.cpp.inc" diff --git a/lib/CAPI/Dialect/Include.cpp b/lib/CAPI/Dialect/Include.cpp index 4560c084d8..a2e935f25f 100644 --- a/lib/CAPI/Dialect/Include.cpp +++ b/lib/CAPI/Dialect/Include.cpp @@ -22,7 +22,7 @@ using namespace llzk::include; -static void registerLLZKIncludeTransformationPasses() { registerTransformationPasses(); } +static inline void registerLLZKIncludeTransformationPasses() { registerTransformationPasses(); } // Include the generated CAPI #include "llzk/Dialect/Include/IR/Ops.capi.cpp.inc" diff --git a/lib/CAPI/Dialect/POD.cpp b/lib/CAPI/Dialect/POD.cpp index eaad94df36..6677ea296d 100644 --- a/lib/CAPI/Dialect/POD.cpp +++ b/lib/CAPI/Dialect/POD.cpp @@ -39,7 +39,7 @@ using namespace mlir; using namespace llzk; using namespace llzk::pod; -static void registerLLZKPodTransformationPasses() { registerTransformationPasses(); } +static inline void registerLLZKPodTransformationPasses() { registerTransformationPasses(); } // Include the generated CAPI #include "llzk/Dialect/POD/IR/Attrs.capi.cpp.inc" diff --git a/lib/CAPI/Dialect/Poly.cpp b/lib/CAPI/Dialect/Poly.cpp index 3a909e856f..cc2920cc8e 100644 --- a/lib/CAPI/Dialect/Poly.cpp +++ b/lib/CAPI/Dialect/Poly.cpp @@ -31,7 +31,7 @@ using namespace mlir; using namespace llzk; using namespace llzk::polymorphic; -static void registerLLZKPolymorphicTransformationPasses() { registerTransformationPasses(); } +static inline void registerLLZKPolymorphicTransformationPasses() { registerTransformationPasses(); } // Include the generated CAPI #include "llzk/Dialect/Polymorphic/IR/Ops.capi.cpp.inc" diff --git a/lib/CAPI/Dialect/Struct.cpp b/lib/CAPI/Dialect/Struct.cpp index ef45b8b3f9..2845acf9ad 100644 --- a/lib/CAPI/Dialect/Struct.cpp +++ b/lib/CAPI/Dialect/Struct.cpp @@ -15,14 +15,17 @@ #include "llzk/Dialect/Struct/IR/Dialect.h" #include "llzk/Dialect/Struct/IR/Ops.h" #include "llzk/Dialect/Struct/IR/Types.h" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h" #include "llzk/Util/Compare.h" #include "llzk/Util/SymbolLookup.h" #include "llzk/Util/TypeHelper.h" #include +#include #include #include +#include #include #include #include @@ -35,9 +38,12 @@ using namespace mlir; using namespace llzk; using namespace llzk::component; +static inline void registerLLZKStructTransformationPasses() { registerTransformationPasses(); } + // Include the generated CAPI #include "llzk/Dialect/Struct/IR/Ops.capi.cpp.inc" #include "llzk/Dialect/Struct/IR/Types.capi.cpp.inc" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.capi.cpp.inc" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Struct, llzk__component, StructDialect) diff --git a/lib/CAPI/Transforms.cpp b/lib/CAPI/Transforms.cpp index 494fe38d70..d1baec226d 100644 --- a/lib/CAPI/Transforms.cpp +++ b/lib/CAPI/Transforms.cpp @@ -14,7 +14,7 @@ using namespace llzk; -static void registerLLZKTransformationPasses() { registerTransformationPasses(); } +static inline void registerLLZKTransformationPasses() { registerTransformationPasses(); } // Impl #include "llzk/Transforms/LLZKTransformationPasses.capi.cpp.inc" diff --git a/lib/Dialect/Polymorphic/CMakeLists.txt b/lib/Dialect/Polymorphic/CMakeLists.txt index a2a1352c48..295847ced6 100644 --- a/lib/Dialect/Polymorphic/CMakeLists.txt +++ b/lib/Dialect/Polymorphic/CMakeLists.txt @@ -5,8 +5,16 @@ target_link_libraries(LLZKAllDialects INTERFACE LLZKPolymorphicDialect) file(GLOB_RECURSE LLZKPolymorphicDialect_SOURCES "**/*.cpp") target_sources(LLZKPolymorphicDialect PRIVATE ${LLZKPolymorphicDialect_SOURCES}) target_link_libraries( - LLZKPolymorphicDialect PUBLIC LLZKDialectHeaders ${LLZK_DEP_DIALECT_LIBS} MLIRIR - MLIRParser LLVMHeaders MLIRHeaders) + LLZKPolymorphicDialect + PUBLIC + LLZKDialectHeaders + ${LLZK_DEP_DIALECT_LIBS} + MLIRIR + MLIRParser + LLVMHeaders + MLIRHeaders + PRIVATE + LLZKTransforms) llzk_target_add_mlir_link_settings(LLZKPolymorphicDialect) install(TARGETS LLZKPolymorphicDialect EXPORT LLZKTargets) diff --git a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp index a56de79236..6d63de0851 100644 --- a/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp +++ b/lib/Dialect/Polymorphic/Transforms/FlatteningPass.cpp @@ -2764,6 +2764,12 @@ class PassImpl : public llzk::polymorphic::impl::FlatteningPassBase { using Base = FlatteningPassBase; using Base::Base; + /// If the cleanup mode is unspecified, default to `Preimage`. + FlatteningCleanupMode getEffectiveCleanupMode() const { + FlatteningCleanupMode m = cleanupMode.getValue(); + return m == FlatteningCleanupMode::Unspecified ? FlatteningCleanupMode::Preimage : m; + } + void runOnOperation() override { ModuleOp modOp = getOperation(); if (failed(runOn(modOp))) { @@ -2779,11 +2785,12 @@ class PassImpl : public llzk::polymorphic::impl::FlatteningPassBase { } inline LogicalResult runOn(ModuleOp modOp) { + FlatteningCleanupMode effectiveCleanupMode = getEffectiveCleanupMode(); // If the cleanup mode is set to remove anything not reachable from the main struct, do an // initial pass to remove things that are not reachable (as an optimization) because creating // an instantiated version of a struct will not cause something to become reachable that was // not already reachable in parameterized form. - if (cleanupMode == FlatteningCleanupMode::MainAsRoot) { + if (effectiveCleanupMode == FlatteningCleanupMode::MainAsRoot) { if (failed(eraseUnreachableFromMainStruct(modOp))) { return failure(); } @@ -2887,8 +2894,9 @@ class PassImpl : public llzk::polymorphic::impl::FlatteningPassBase { // Perform cleanup according to the 'cleanupMode' option. LogicalResult cleanupSwitch(ModuleOp modOp, const ConversionTracker &tracker) { + FlatteningCleanupMode effectiveCleanupMode = getEffectiveCleanupMode(); LLVM_DEBUG({ llvm::dbgs() << "[FlatteningPass] Running step 5: cleanup "; }); - switch (cleanupMode) { + switch (effectiveCleanupMode) { case FlatteningCleanupMode::MainAsRoot: LLVM_DEBUG(llvm::dbgs() << "(main as root mode)\n"); return eraseUnreachableFromMainStruct(modOp, false); @@ -2898,6 +2906,7 @@ class PassImpl : public llzk::polymorphic::impl::FlatteningPassBase { case FlatteningCleanupMode::Preimage: LLVM_DEBUG(llvm::dbgs() << "(preimage mode)\n"); return erasePreimageOfInstantiations(modOp, tracker); + case FlatteningCleanupMode::Unspecified: default: LLVM_DEBUG(llvm::dbgs() << "(disabled)\n"); return success(); diff --git a/lib/Dialect/SMT/SMTAttributes.cpp b/lib/Dialect/SMT/SMTAttributes.cpp index f11e9d0f13..e1024eef7e 100644 --- a/lib/Dialect/SMT/SMTAttributes.cpp +++ b/lib/Dialect/SMT/SMTAttributes.cpp @@ -1,4 +1,4 @@ -//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// +//===- SMTAttributes.cpp - Implement SMT attributes -------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/Dialect/SMT/SMTDialect.cpp b/lib/Dialect/SMT/SMTDialect.cpp index 0f9741729c..af5ceb3e35 100644 --- a/lib/Dialect/SMT/SMTDialect.cpp +++ b/lib/Dialect/SMT/SMTDialect.cpp @@ -1,4 +1,4 @@ -//===- SMTDialect.cpp - SMT dialect implementation ------------------------===// +//===- SMTDialect.cpp - SMT dialect implementation --------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index bd78348448..d87bf3d7d0 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -1,4 +1,4 @@ -//===- SMTOps.cpp ---------------------------------------------------------===// +//===- SMTOps.cpp -----------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/Dialect/SMT/SMTTypes.cpp b/lib/Dialect/SMT/SMTTypes.cpp index e5813255a6..0d5de6c21a 100644 --- a/lib/Dialect/SMT/SMTTypes.cpp +++ b/lib/Dialect/SMT/SMTTypes.cpp @@ -1,4 +1,4 @@ -//===- SMTTypes.cpp -------------------------------------------------------===// +//===- SMTTypes.cpp ---------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/Dialect/Struct/CMakeLists.txt b/lib/Dialect/Struct/CMakeLists.txt index b72991ad1b..c2e3c43da1 100644 --- a/lib/Dialect/Struct/CMakeLists.txt +++ b/lib/Dialect/Struct/CMakeLists.txt @@ -7,7 +7,7 @@ target_sources(LLZKStructDialect PRIVATE ${LLZKStructDialect_SOURCES}) target_link_libraries( LLZKStructDialect PUBLIC LLZKDialectHeaders ${LLZK_DEP_DIALECT_LIBS} MLIRIR MLIRPass MLIRParser MLIRTransformUtils MLIRSCFTransforms - LLVMHeaders MLIRHeaders LLZKUtil LLZKDialect) + LLVMHeaders MLIRHeaders LLZKAnalysis LLZKUtil LLZKDialect) llzk_target_add_mlir_link_settings(LLZKStructDialect) install(TARGETS LLZKStructDialect EXPORT LLZKTargets) diff --git a/lib/Transforms/LLZKInlineStructsPass.cpp b/lib/Dialect/Struct/Transforms/InlineStructsPass.cpp similarity index 97% rename from lib/Transforms/LLZKInlineStructsPass.cpp rename to lib/Dialect/Struct/Transforms/InlineStructsPass.cpp index af42649911..eb9de621d2 100644 --- a/lib/Transforms/LLZKInlineStructsPass.cpp +++ b/lib/Dialect/Struct/Transforms/InlineStructsPass.cpp @@ -1,4 +1,4 @@ -//===-- LLZKInlineStructsPass.cpp -------------------------------*- C++ -*-===// +//===-- InlineStructsPass.cpp -----------------------------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. @@ -18,7 +18,7 @@ /// //===----------------------------------------------------------------------===// -#include "llzk/Transforms/LLZKInlineStructsPass.h" +#include "llzk/Dialect/Struct/Transforms/InlineStructsPass.h" #include "llzk/Analysis/SymbolUseGraph.h" #include "llzk/Dialect/Constrain/IR/Ops.h" @@ -26,8 +26,8 @@ #include "llzk/Dialect/Function/IR/Ops.h" #include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/Struct/IR/Ops.h" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h" #include "llzk/Transforms/LLZKConversionUtils.h" -#include "llzk/Transforms/LLZKTransformationPasses.h" #include "llzk/Util/Debug.h" #include "llzk/Util/SymbolHelper.h" #include "llzk/Util/SymbolLookup.h" @@ -47,10 +47,10 @@ #include // Include the generated base pass class definitions. -namespace llzk { +namespace llzk::component { #define GEN_PASS_DEF_INLINESTRUCTSPASS -#include "llzk/Transforms/LLZKTransformationPasses.h.inc" -} // namespace llzk +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h.inc" +} // namespace llzk::component using namespace mlir; using namespace llzk; @@ -595,6 +595,11 @@ class DanglingUseHandler { } private: + /// Rewrite a call that still consumes `origin` after inlining so the callee takes each cloned + /// member as a separate argument and the call site materializes matching `struct.readm` values. + /// + /// This only supports calls to resolvable, non-external functions because the pass must update + /// both the call operation and the callee signature in lockstep. inline LogicalResult handleUseInCallOp(OpOperand &use, CallOp inCall, Operation *origin) const { LLVM_DEBUG( llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall << '\n' @@ -806,6 +811,9 @@ class DanglingUseHandler { } }; +/// Apply the post-inlining cleanup for one caller struct by folding rewritten member-read +/// chains, resolving dangling uses of soon-to-be-erased ops, and then erasing the obsolete +/// cloned scaffolding in dependency order. static LogicalResult finalizeStruct( SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete, DestToSrcToClonedSrcInDest &&destToSrcToClone @@ -894,6 +902,8 @@ static LogicalResult finalizeStruct( } // namespace +/// Execute the inlining plan one caller struct at a time, accumulating per-callee member +/// replacement maps and then finalizing each caller after all requested callees have been inlined. LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan) { for (auto &[caller, callees] : plan) { // Cache operations that should be deleted but must wait until all callees are processed @@ -926,7 +936,7 @@ LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan) namespace { -class PassImpl : public llzk::impl::InlineStructsPassBase { +class PassImpl : public llzk::component::impl::InlineStructsPassBase { using Base = InlineStructsPassBase; using Base::Base; @@ -1007,6 +1017,8 @@ class PassImpl : public llzk::impl::InlineStructsPassBase { return res.wasInterrupted(); } + /// Reject symbol-use graphs that contain references to template `param` or `expr` ops. This pass + /// only supports concrete struct instances (run `llzk-flatten` to instantiate templates first). static LogicalResult verifyNoTemplateSymbolBindings(const SymbolUseGraph &useGraph, SymbolTableCollection &tables) { for (const SymbolUseGraphNode *node : useGraph.nodesIter()) { @@ -1026,6 +1038,8 @@ class PassImpl : public llzk::impl::InlineStructsPassBase { return success(); } + /// Emit a diagnostic for a cycle discovered while traversing the symbol-use slice reachable + /// from struct "@constrain" functions, attaching notes for each symbol in the cycle. static LogicalResult emitConstrainReachableCycleError( ArrayRef dfsStack, const SymbolUseGraphNode *cycleHead, SymbolTableCollection &tables diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 442b963b42..bc7c27fb38 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -15,8 +15,8 @@ target_link_libraries( ${LLZK_DEP_DIALECT_LIBS} MLIRIR MLIRPass MLIRParser MLIRTransformUtils MLIRSCFTransforms LLVMHeaders MLIRHeaders + LLZKAllDialects PRIVATE - LLZKDialect LLZKAnalysis ) diff --git a/lib/Transforms/LLZKComputeConstrainToProductPass.cpp b/lib/Transforms/LLZKComputeConstrainToProductPass.cpp index 99135153be..f4fc1d28bd 100644 --- a/lib/Transforms/LLZKComputeConstrainToProductPass.cpp +++ b/lib/Transforms/LLZKComputeConstrainToProductPass.cpp @@ -17,7 +17,7 @@ #include "llzk/Analysis/LightweightSignalEquivalenceAnalysis.h" #include "llzk/Dialect/Function/IR/Ops.h" #include "llzk/Dialect/Struct/IR/Ops.h" -#include "llzk/Transforms/LLZKInlineStructsPass.h" +#include "llzk/Dialect/Struct/Transforms/InlineStructsPass.h" #include "llzk/Transforms/LLZKTransformationPasses.h" #include "llzk/Util/Constants.h" #include "llzk/Util/SymbolHelper.h" diff --git a/lib/Transforms/LLZKFuseProductLoopsPass.cpp b/lib/Transforms/LLZKFuseProductLoopsPass.cpp index 16cf02f97c..441287ec0b 100644 --- a/lib/Transforms/LLZKFuseProductLoopsPass.cpp +++ b/lib/Transforms/LLZKFuseProductLoopsPass.cpp @@ -1,4 +1,4 @@ -//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===// +//===-- LLZKFuseProductLoopsPass.cpp ----------------------------*- C++ -*-===// // // Part of the LLZK Project, under the Apache License v2.0. // See LICENSE.txt for license information. diff --git a/lib/Transforms/LLZKLoweringUtils.cpp b/lib/Transforms/LLZKLoweringUtils.cpp index ee206bcd11..2dd88c9fa4 100644 --- a/lib/Transforms/LLZKLoweringUtils.cpp +++ b/lib/Transforms/LLZKLoweringUtils.cpp @@ -1,4 +1,4 @@ -//===-- LLZKLoweringUtils.cpp --------------------------------*- C++ -*----===// +//===-- LLZKLoweringUtils.cpp -----------------------------------*- C++ -*-===// // // Shared utility function implementations for LLZK lowering passes. // diff --git a/lib/Transforms/LLZKTransformationPassPipelines.cpp b/lib/Transforms/LLZKTransformationPassPipelines.cpp index 32e06e632d..1c87cfcb6f 100644 --- a/lib/Transforms/LLZKTransformationPassPipelines.cpp +++ b/lib/Transforms/LLZKTransformationPassPipelines.cpp @@ -12,64 +12,162 @@ /// //===----------------------------------------------------------------------===// -#include "llzk/Transforms/LLZKTransformationPasses.h" +#include "llzk/Transforms/LLZKTransformationPassPipelines.h" + +#include "llzk/Dialect/Array/Transforms/TransformationPasses.h" +#include "llzk/Dialect/POD/Transforms/TransformationPasses.h" #include #include #include +#include + using namespace mlir; namespace llzk { -struct FullPolyLoweringOptions : public PassPipelineOptions { - Option maxDegree { - *this, "max-degree", llvm::cl::desc("Maximum polynomial degree (must be ≥ 2)"), - llvm::cl::init(2) - }; -}; +//===----------------------------------------------------------------------===// +// Pipeline implementation. +//===----------------------------------------------------------------------===// + +namespace { + +template +inline std::unique_ptr createConfiguredPass(const NestedPassOptionT &options) { + return options.getValue().createPass(); +} + +void buildFullStructInliningPipelineImpl( + OpPassManager &pm, polymorphic::FlatteningPassOptions flattening, bool arrayToScalar, + bool podToScalar, std::unique_ptr inliningPass +) { + // default to `main-as-root` if unspecified to avoid leaving parameterized templates + // that cause the later struct inlining pass to crash + if (flattening.cleanupMode == polymorphic::FlatteningCleanupMode::Unspecified) { + flattening.cleanupMode = polymorphic::FlatteningCleanupMode::MainAsRoot; + } + pm.addPass(polymorphic::createFlatteningPass(flattening)); + + // Run array-to-scalar first because it can split arrays within a pod + // but pod-to-scalar cannot split pods within an array. + if (arrayToScalar) { + pm.addPass(array::createArrayToScalarPass()); + } + if (podToScalar) { + pm.addPass(pod::createPodToScalarPass()); + } + // Canonicalize to remove known-condition `scf.if` regions so struct inlining + // can link "@compute" calls to struct members. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(std::move(inliningPass)); -void addRemoveUnnecessaryOpsAndDefsPipeline(OpPassManager &pm) { + // Remove struct and member definitions that are no longer used after inlining. + pm.addPass(createUnusedDeclarationEliminationPass( + UnusedDeclarationEliminationPassOptions {.removeStructs = true} + )); +} + +void buildFullPolyLoweringPipelineImpl( + OpPassManager &pm, polymorphic::FlatteningPassOptions flattening, bool arrayToScalar, + bool podToScalar, std::unique_ptr inliningPass, std::unique_ptr polyLoweringPass +) { + // 1. Struct flattening and inlining + buildFullStructInliningPipelineImpl( + pm, flattening, arrayToScalar, podToScalar, std::move(inliningPass) + ); + // 2. Degree lowering + pm.addPass(std::move(polyLoweringPass)); + // 3. Cleanup + buildRemoveUnnecessaryOpsAndDefsPipeline(pm); +} + +} // namespace + +void buildRemoveUnnecessaryOpsPipeline(mlir::OpPassManager &pm) { pm.addPass(createRedundantReadAndWriteEliminationPass()); pm.addPass(createRedundantOperationEliminationPass()); +} + +void buildRemoveUnnecessaryOpsAndDefsPipeline(mlir::OpPassManager &pm) { + buildRemoveUnnecessaryOpsPipeline(pm); pm.addPass(createUnusedDeclarationEliminationPass()); } +void buildProductProgramPipeline(OpPassManager &pm) { + pm.addPass(createComputeConstrainToProductPass()); + pm.addPass(createFuseProductLoopsPass()); +} + +void buildFullStructInliningPipeline(OpPassManager &pm, const FullStructInliningConfig &cfg) { + buildFullStructInliningPipelineImpl( + pm, cfg.flattening, cfg.arrayToScalar, cfg.podToScalar, + component::createInlineStructsPass(cfg.inlining) + ); +} + +void buildFullPolyLoweringPipeline(OpPassManager &pm, const FullPolyLoweringConfig &cfg) { + buildFullPolyLoweringPipelineImpl( + pm, cfg.structInlining.flattening, cfg.structInlining.arrayToScalar, + cfg.structInlining.podToScalar, + component::createInlineStructsPass(cfg.structInlining.inlining), + createPolyLoweringPass(cfg.polyLowering) + ); +} + +//===----------------------------------------------------------------------===// +// Pipeline registration. +//===----------------------------------------------------------------------===// + void registerTransformationPassPipelines() { PassPipelineRegistration<>( "llzk-remove-unnecessary-ops", "Remove unnecessary operations, such as redundant reads or repeated constraints", - [](OpPassManager &pm) { - pm.addPass(createRedundantReadAndWriteEliminationPass()); - pm.addPass(createRedundantOperationEliminationPass()); - } + buildRemoveUnnecessaryOpsPipeline ); PassPipelineRegistration<>( "llzk-remove-unnecessary-ops-and-defs", "Remove unnecessary operations, member definitions, and struct definitions", - [](OpPassManager &pm) { addRemoveUnnecessaryOpsAndDefsPipeline(pm); } + buildRemoveUnnecessaryOpsAndDefsPipeline ); - PassPipelineRegistration( - "llzk-full-poly-lowering", - "Lower already-flattened polynomial constraints to a given max degree, then remove " - "unnecessary operations and definitions.", - [](OpPassManager &pm, const FullPolyLoweringOptions &opts) { - // 1. Degree lowering - pm.addPass(createPolyLoweringPass(PolyLoweringPassOptions {.maxDegree = opts.maxDegree})); + PassPipelineRegistration<>( + "llzk-product-program", + "Convert @compute/@constrain functions to @product function and perform alignment", + buildProductProgramPipeline + ); - // 2. Cleanup - addRemoveUnnecessaryOpsAndDefsPipeline(pm); + PassPipelineRegistration( + "llzk-full-struct-inlining", + "Run flattening and inlining of all struct definitions into the `main` struct. This " + "pipeline uses the `main-as-root` cleanup mode in the flattening pass by default. It " + "is not recommended to override this cleanup mode because other cleanup modes may " + "leave behind parameterized templates that later cause `llzk-inline-structs` to crash.", + [](OpPassManager &pm, const FullStructInliningOptions &opts) { + auto flattening = opts.flattening.getValue().createOptions(); + buildFullStructInliningPipelineImpl( + pm, flattening->createPassOptions(), opts.arrayToScalar, opts.podToScalar, + createConfiguredPass(opts.inlining) + ); } ); - PassPipelineRegistration<>( - "llzk-product-program", - "Convert @compute/@constrain functions to @product function and perform alignment", - [](OpPassManager &pm) { - pm.addPass(createComputeConstrainToProductPass()); - pm.addPass(createFuseProductLoopsPass()); + PassPipelineRegistration( + "llzk-full-poly-lowering", + "Run flattening and inlining of all struct definitions into the `main` struct, then lower " + "polynomial constraints to a given max degree, and finally remove unnecessary operations and " + "definitions. This pipeline uses the `main-as-root` cleanup mode in the flattening pass by " + "default. It is not recommended to override this cleanup mode because other cleanup modes " + "may leave behind parameterized templates that later cause `llzk-inline-structs` to crash.", + [](OpPassManager &pm, const FullPolyLoweringOptions &opts) { + auto structInlining = opts.structInlining.getValue().createOptions(); + auto flattening = structInlining->flattening.getValue().createOptions(); + buildFullPolyLoweringPipelineImpl( + pm, flattening->createPassOptions(), structInlining->arrayToScalar, + structInlining->podToScalar, createConfiguredPass(structInlining->inlining), + createConfiguredPass(opts.polyLowering) + ); } ); } diff --git a/lib/Transforms/LLZKUnusedDeclarationEliminationPass.cpp b/lib/Transforms/LLZKUnusedDeclarationEliminationPass.cpp index 90c4cf31da..a40ef3b4bc 100644 --- a/lib/Transforms/LLZKUnusedDeclarationEliminationPass.cpp +++ b/lib/Transforms/LLZKUnusedDeclarationEliminationPass.cpp @@ -78,6 +78,7 @@ class PassImpl : public llzk::impl::UnusedDeclarationEliminationPassBase member ops DenseMap members; for (auto &[structDef, structSym] : ctx.structToSymbol) { - structDef.walk([&](MemberDefOp member) { - // We don't consider public members in the Main component for removal, - // as these are output values and removing them would result in modifying - // the overall circuit interface. - // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) - if (!structDef.isMainComponent() || !member.hasPublicAttr()) { + bool notMain = !structDef.isMainComponent(); + structDef.walk([notMain, &structSym, &members](MemberDefOp member) { + // We don't consider public members in the Main component for removal, as these are output + // values and removing them would result in modifying the overall circuit interface. + if (notMain || !member.hasPublicAttr()) { SymbolRefAttr memberSym = appendLeaf(structSym, FlatSymbolRefAttr::get(member.getSymNameAttr())); members[memberSym] = member; @@ -104,14 +104,11 @@ class PassImpl : public llzk::impl::UnusedDeclarationEliminationPassBasegetAttrs()) { - namedAttr.getValue().walk([&](TypeAttr typeAttr) { tryAddUse(typeAttr.getValue()); }); + namedAttr.getValue().walk([&tryAddUse](TypeAttr typeAttr) { + tryAddUse(typeAttr.getValue()); + }); } return WalkResult::advance(); @@ -194,7 +193,7 @@ class PassImpl : public llzk::impl::UnusedDeclarationEliminationPassBase unusedStructs; - auto updateUnusedStructs = [&]() { + auto updateUnusedStructs = [&usedBy, &unusedStructs]() { for (auto &[structDef, users] : usedBy) { if (users.empty() && !structDef.isMainComponent()) { unusedStructs.push_back(structDef); @@ -226,6 +225,27 @@ class PassImpl : public llzk::impl::UnusedDeclarationEliminationPassBase emptyModules; + + ModuleOp rootModOp = getOperation(); + rootModOp.walk([&](ModuleOp modOp) { + if (modOp == rootModOp) { + return; + } + Region ®ion = modOp.getBodyRegion(); + if (region.empty() || region.front().empty()) { // module has `SingleBlock` trait + emptyModules.push_back(modOp); + } + }); + + for (ModuleOp modOp : emptyModules) { + LLVM_DEBUG(llvm::dbgs() << "Removing empty module " << modOp.getName() << '\n'); + modOp->erase(); + } + } }; } // namespace diff --git a/nix/llzk.nix b/nix/llzk.nix index 1831b05702..3552765fc9 100644 --- a/nix/llzk.nix +++ b/nix/llzk.nix @@ -12,7 +12,7 @@ }: let - version = "2.1.2"; + version = "3.0.0"; in stdenv.mkDerivation { pname = "llzk-${lib.toLower cmakeBuildType}"; diff --git a/test/Transforms/InlineStructs/circom_subcmps5b.llzk b/test/Transforms/InlineStructs/circom_subcmps5b.llzk index b8c3aa832d..acb2bc8a56 100644 --- a/test/Transforms/InlineStructs/circom_subcmps5b.llzk +++ b/test/Transforms/InlineStructs/circom_subcmps5b.llzk @@ -1,9 +1,4 @@ -// RUN: llzk-opt -llzk-flatten -llzk-pod-to-scalar -canonicalize -llzk-inline-structs %s 2>&1 | FileCheck --enable-var-scope %s -// -// COM: the following passes should run before inlining structs: -// COM: - llzk-flatten: to instantiate templated structs and unroll loops -// COM: - llzk-pod-to-scalar: to convert aggregate pod types into scalar SSA values -// COM: - canonicalize: to remove known condition `scf.if` regions so struct inlining can link "@compute" calls to struct members +// RUN: llzk-opt -llzk-full-struct-inlining %s 2>&1 | FileCheck --enable-var-scope %s ////////////////////////////////////////////////////////////////////////////////// // template Nop() { @@ -90,21 +85,6 @@ module attributes {llzk.lang = "circom", llzk.main = !struct.type<@SubCmp::@SubC } // CHECK-LABEL: module attributes {llzk.lang = "circom", llzk.main = !struct.type<@SubCmp::@SubCmp<[]>>} { -// CHECK-NEXT: module @Nop { -// CHECK-NEXT: struct.def @Nop { -// CHECK-NEXT: struct.member @o : !felt.type<"bn128"> {llzk.pub} -// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128"> {function.arg_name = "i"}) -> !struct.type<@Nop::@Nop> attributes {function.allow_non_native_field_ops, function.allow_witness} { -// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Nop::@Nop> -// CHECK-NEXT: struct.writem %[[VAL_1]][@o] = %[[VAL_0]] : <@Nop::@Nop>, !felt.type<"bn128"> -// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Nop::@Nop> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Nop::@Nop>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: !felt.type<"bn128"> {function.arg_name = "i"}) attributes {function.allow_constraint, function.allow_non_native_field_ops} { -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@o] : <@Nop::@Nop>, !felt.type<"bn128"> -// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_3]] : !felt.type<"bn128">, !felt.type<"bn128"> -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } // CHECK-NEXT: module @SubCmp { // CHECK-NEXT: struct.def @SubCmp { // CHECK-NEXT: struct.member @o : !felt.type<"bn128"> {llzk.pub} diff --git a/test/Transforms/InlineStructs/full_struct_inlining_cleanup_bug.llzk b/test/Transforms/InlineStructs/full_struct_inlining_cleanup_bug.llzk new file mode 100644 index 0000000000..1d75551030 --- /dev/null +++ b/test/Transforms/InlineStructs/full_struct_inlining_cleanup_bug.llzk @@ -0,0 +1,88 @@ +// These pipelines default their flattening stage to `cleanup=main-as-root`, which +// avoids leaving unreachable partially instantiated templates behind. + +// RUN: llzk-opt -llzk-full-struct-inlining %s | FileCheck %s +// RUN: llzk-opt -llzk-full-poly-lowering %s | FileCheck %s + +// Explicitly overriding that cleanup back to `preimage` in the underlying flatten+inline +// sequence reproduces the old error (checked by the `-verify-diagnostics` flag). + +// RUN: llzk-opt --pass-pipeline='builtin.module(llzk-flatten{cleanup=preimage},llzk-inline-structs)' -verify-diagnostics %s + +module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { + // expected-error@+1 {{Cannot inline struct within a template. Run `llzk-flatten` to instantiate templated structs.}} + poly.template @TComponentY { + poly.param @A + poly.param @B + + struct.def @ComponentY { + function.def @compute() -> !struct.type<@TComponentY::@ComponentY<[@A, @B]>> { + %self = struct.new : !struct.type<@TComponentY::@ComponentY<[@A, @B]>> + function.return %self : !struct.type<@TComponentY::@ComponentY<[@A, @B]>> + } + + function.def @constrain(%self: !struct.type<@TComponentY::@ComponentY<[@A, @B]>>) { + function.return + } + } + } + + poly.template @TComponentX { + poly.param @C + + struct.def @ComponentX { + struct.member @f2 : !struct.type<@TComponentY::@ComponentY<[5, @C]>> + + function.def @compute() -> !struct.type<@TComponentX::@ComponentX<[@C]>> { + %self = struct.new : !struct.type<@TComponentX::@ComponentX<[@C]>> + %x = function.call @TComponentY::@ComponentY::@compute() + : () -> !struct.type<@TComponentY::@ComponentY<[5, @C]>> + struct.writem %self[@f2] = %x + : !struct.type<@TComponentX::@ComponentX<[@C]>>, + !struct.type<@TComponentY::@ComponentY<[5, @C]>> + function.return %self : !struct.type<@TComponentX::@ComponentX<[@C]>> + } + + function.def @constrain(%self: !struct.type<@TComponentX::@ComponentX<[@C]>>) { + %b = struct.readm %self[@f2] + : !struct.type<@TComponentX::@ComponentX<[@C]>>, + !struct.type<@TComponentY::@ComponentY<[5, @C]>> + function.call @TComponentY::@ComponentY::@constrain(%b) + : (!struct.type<@TComponentY::@ComponentY<[5, @C]>>) -> () + function.return + } + } + } + + struct.def @Main { + struct.member @f3 : !struct.type<@TComponentX::@ComponentX<[43]>> + + function.def @compute() -> !struct.type<@Main> { + %self = struct.new : !struct.type<@Main> + %x = function.call @TComponentX::@ComponentX::@compute() + : () -> !struct.type<@TComponentX::@ComponentX<[43]>> + struct.writem %self[@f3] = %x + : !struct.type<@Main>, !struct.type<@TComponentX::@ComponentX<[43]>> + function.return %self : !struct.type<@Main> + } + + function.def @constrain(%self: !struct.type<@Main>) { + %b = struct.readm %self[@f3] + : !struct.type<@Main>, !struct.type<@TComponentX::@ComponentX<[43]>> + function.call @TComponentX::@ComponentX::@constrain(%b) + : (!struct.type<@TComponentX::@ComponentX<[43]>>) -> () + function.return + } + } +} +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/InlineStructs/inline_structs_max_complexity.llzk b/test/Transforms/InlineStructs/inline_structs_max_complexity.llzk index f8d9838404..5f99fd55fb 100644 --- a/test/Transforms/InlineStructs/inline_structs_max_complexity.llzk +++ b/test/Transforms/InlineStructs/inline_structs_max_complexity.llzk @@ -1,4 +1,4 @@ -// RUN: llzk-opt -split-input-file --pass-pipeline='builtin.module(llzk-flatten,llzk-inline-structs{max-merge-complexity=2})' -verify-diagnostics %s | FileCheck %s +// RUN: llzk-opt -split-input-file -llzk-full-struct-inlining="array-to-scalar=0 pod-to-scalar=false inlining={max-merge-complexity=2}" -verify-diagnostics %s | FileCheck %s // TESTS: full inlining is within `maxComplexity` limit module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { @@ -58,18 +58,19 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Main { +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Main { // CHECK-NEXT: struct.member @"f:!s<@Component1B>+f2:!s<@Component1A>+f1" : !felt.type // CHECK-NEXT: struct.member @"f:!s<@Component1B>+f2:!s<@Component1A>+f2" : !felt.type // CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 42 // CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> // CHECK-NEXT: struct.writem %[[VAL_1]][@"f:!s<@Component1B>+f2:!s<@Component1A>+f1"] = %[[VAL_0]] : <@Main>, !felt.type // CHECK-NEXT: struct.writem %[[VAL_1]][@"f:!s<@Component1B>+f2:!s<@Component1A>+f2"] = %[[VAL_0]] : <@Main>, !felt.type // CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Main> // CHECK-NEXT: } // CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 42 // CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@"f:!s<@Component1B>+f2:!s<@Component1A>+f1"] : <@Main>, !felt.type // CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_3]] : !felt.type, !felt.type // CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@"f:!s<@Component1B>+f2:!s<@Component1A>+f2"] : <@Main>, !felt.type @@ -77,6 +78,7 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // ----- // TESTS: A->B inlining succeeds but `maxComplexity` prevents inlining B->Main @@ -138,7 +140,8 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component2B { +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component2B { // CHECK-NEXT: struct.member @"f2:!s<@Component2A>+f1" : !felt.type // CHECK-NEXT: struct.member @"f2:!s<@Component2A>+f2" : !felt.type // CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component2B> attributes {function.allow_witness} { @@ -155,24 +158,24 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { +// CHECK-NEXT: struct.def @Main { // CHECK-NEXT: struct.member @f : !struct.type<@Component2B> // CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 42 -// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @Component2B::@compute(%[[VAL_0]]) : (!felt.type) -> !struct.type<@Component2B> -// CHECK-NEXT: struct.writem %[[VAL_1]][@f] = %[[VAL_2]] : <@Main>, !struct.type<@Component2B> -// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Main> +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = function.call @Component2B::@compute(%[[VAL_6]]) : (!felt.type) -> !struct.type<@Component2B> +// CHECK-NEXT: struct.writem %[[VAL_7]][@f] = %[[VAL_8]] : <@Main>, !struct.type<@Component2B> +// CHECK-NEXT: function.return %[[VAL_7]] : !struct.type<@Main> // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 42 -// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@f] : <@Main>, !struct.type<@Component2B> -// CHECK-NEXT: function.call @Component2B::@constrain(%[[VAL_5]], %[[VAL_4]]) : (!struct.type<@Component2B>, !felt.type) -> () -// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_4]] : !felt.type, !felt.type +// CHECK-NEXT: function.def @constrain(%[[VAL_9:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_9]][@f] : <@Main>, !struct.type<@Component2B> +// CHECK-NEXT: function.call @Component2B::@constrain(%[[VAL_11]], %[[VAL_10]]) : (!struct.type<@Component2B>, !felt.type) -> () +// CHECK-NEXT: constrain.eq %[[VAL_10]], %[[VAL_10]] : !felt.type, !felt.type // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // ----- // TESTS: `maxComplexity` prevents A->B inlining but B->Main inlining succeeds @@ -235,7 +238,8 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component3A { +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component3A { // CHECK-NEXT: struct.member @f1 : !felt.type // CHECK-NEXT: struct.member @f2 : !felt.type // CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component3A> attributes {function.allow_witness} { @@ -253,24 +257,24 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { +// CHECK-NEXT: struct.def @Main { // CHECK-NEXT: struct.member @"f:!s<@Component3B>+f2" : !struct.type<@Component3A> // CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 42 -// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @Component3A::@compute(%[[VAL_0]]) : (!felt.type) -> !struct.type<@Component3A> -// CHECK-NEXT: struct.writem %[[VAL_1]][@"f:!s<@Component3B>+f2"] = %[[VAL_2]] : <@Main>, !struct.type<@Component3A> -// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Main> +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = function.call @Component3A::@compute(%[[VAL_6]]) : (!felt.type) -> !struct.type<@Component3A> +// CHECK-NEXT: struct.writem %[[VAL_7]][@"f:!s<@Component3B>+f2"] = %[[VAL_8]] : <@Main>, !struct.type<@Component3A> +// CHECK-NEXT: function.return %[[VAL_7]] : !struct.type<@Main> // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 42 -// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@"f:!s<@Component3B>+f2"] : <@Main>, !struct.type<@Component3A> -// CHECK-NEXT: function.call @Component3A::@constrain(%[[VAL_5]], %[[VAL_4]]) : (!struct.type<@Component3A>, !felt.type) -> () -// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_4]] : !felt.type, !felt.type +// CHECK-NEXT: function.def @constrain(%[[VAL_9:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.const 42 +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_9]][@"f:!s<@Component3B>+f2"] : <@Main>, !struct.type<@Component3A> +// CHECK-NEXT: function.call @Component3A::@constrain(%[[VAL_11]], %[[VAL_10]]) : (!struct.type<@Component3A>, !felt.type) -> () +// CHECK-NEXT: constrain.eq %[[VAL_10]], %[[VAL_10]] : !felt.type, !felt.type // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // ----- // TESTS: `maxComplexity` prevents some inlining for multiple callees inlined into the same caller @@ -335,7 +339,8 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component4A { +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component4A { // CHECK-NEXT: struct.member @f1 : !felt.type // CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component4A> attributes {function.allow_witness} { // CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Component4A> @@ -350,27 +355,27 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { +// CHECK-NEXT: struct.def @Main { // CHECK-NEXT: struct.member @"f1:!s<@Component4B>+f1" : !felt.type // CHECK-NEXT: struct.member @f2 : !struct.type<@Component4A> // CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 123 -// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = function.call @Component4A::@compute(%[[VAL_0]], %[[VAL_0]]) : (!felt.type, !felt.type) -> !struct.type<@Component4A> -// CHECK-NEXT: struct.writem %[[VAL_1]][@"f1:!s<@Component4B>+f1"] = %[[VAL_0]] : <@Main>, !felt.type -// CHECK-NEXT: struct.writem %[[VAL_1]][@f2] = %[[VAL_2]] : <@Main>, !struct.type<@Component4A> -// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Main> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.const 123 +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = function.call @Component4A::@compute(%[[VAL_9]], %[[VAL_9]]) : (!felt.type, !felt.type) -> !struct.type<@Component4A> +// CHECK-NEXT: struct.writem %[[VAL_10]][@"f1:!s<@Component4B>+f1"] = %[[VAL_9]] : <@Main>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_10]][@f2] = %[[VAL_11]] : <@Main>, !struct.type<@Component4A> +// CHECK-NEXT: function.return %[[VAL_10]] : !struct.type<@Main> // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.const 123 -// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@f2] : <@Main>, !struct.type<@Component4A> -// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_3]][@"f1:!s<@Component4B>+f1"] : <@Main>, !felt.type -// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_6]] : !felt.type, !felt.type -// CHECK-NEXT: function.call @Component4A::@constrain(%[[VAL_5]], %[[VAL_4]], %[[VAL_4]]) : (!struct.type<@Component4A>, !felt.type, !felt.type) -> () +// CHECK-NEXT: function.def @constrain(%[[VAL_12:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = felt.const 123 +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_12]][@f2] : <@Main>, !struct.type<@Component4A> +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_12]][@"f1:!s<@Component4B>+f1"] : <@Main>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_13]], %[[VAL_15]] : !felt.type, !felt.type +// CHECK-NEXT: function.call @Component4A::@constrain(%[[VAL_14]], %[[VAL_13]], %[[VAL_13]]) : (!struct.type<@Component4A>, !felt.type, !felt.type) -> () // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // ----- // TESTS: full inlining is within `maxComplexity` limit (multiple calls from same struct) @@ -417,21 +422,8 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component5B { -// CHECK-NEXT: struct.member @f1 : !felt.type -// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component5B> attributes {function.allow_witness} { -// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component5B> -// CHECK-NEXT: struct.writem %[[VAL_1]][@f1] = %[[VAL_0]] : <@Component5B>, !felt.type -// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Component5B> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Component5B>, %[[VAL_3:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@f1] : <@Component5B>, !felt.type -// CHECK-NEXT: constrain.eq %[[VAL_3]], %[[VAL_4]] : !felt.type, !felt.type -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Main { // CHECK-NEXT: struct.member @"f1:!s<@Component5B>+f1" : !felt.type // CHECK-NEXT: struct.member @"f2:!s<@Component5B>+f1" : !felt.type // CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { @@ -450,6 +442,7 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { // CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } +// CHECK-NEXT: } // ----- // TESTS: 'max-merge-complexity' prevents inlining A to Main but the others are inlined. @@ -539,50 +532,51 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component6A { -// CHECK-NEXT: struct.member @x : !felt.type -// CHECK-NEXT: struct.member @y : !felt.type -// CHECK-NEXT: struct.member @z : !felt.type -// CHECK-NEXT: function.def @compute() -> !struct.type<@Component6A> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component6A> -// CHECK-NEXT: struct.writem %[[V1]][@x] = %[[V0]] : <@Component6A>, !felt.type -// CHECK-NEXT: struct.writem %[[V1]][@y] = %[[V0]] : <@Component6A>, !felt.type -// CHECK-NEXT: struct.writem %[[V1]][@z] = %[[V0]] : <@Component6A>, !felt.type -// CHECK-NEXT: function.return %[[V1]] : !struct.type<@Component6A> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V2:[0-9a-zA-Z_\.]+]]: !struct.type<@Component6A>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@x] : <@Component6A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V4]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@y] : <@Component6A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V5]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V6:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@z] : <@Component6A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V6]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: struct.member @fa : !struct.type<@Component6A> -// CHECK-NEXT: struct.member @"fc:!s<@Component6C>+fb:!s<@Component6B>+f" : !felt.type -// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = function.call @Component6A::@compute() : () -> !struct.type<@Component6A> -// CHECK-NEXT: struct.writem %[[V0]][@fa] = %[[V1]] : <@Main>, !struct.type<@Component6A> -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = struct.readm %[[V1]][@x] : <@Component6A>, !felt.type -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = felt.mul %[[V2]], %[[V2]] : !felt.type, !felt.type -// CHECK-NEXT: struct.writem %[[V0]][@"fc:!s<@Component6C>+fb:!s<@Component6B>+f"] = %[[V3]] : <@Main>, !felt.type -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@Main> +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component6A { +// CHECK-NEXT: struct.member @x : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @y : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @z : !felt.type {llzk.pub} +// CHECK-NEXT: function.def @compute() -> !struct.type<@Component6A> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component6A> +// CHECK-NEXT: struct.writem %[[VAL_1]][@x] = %[[VAL_0]] : <@Component6A>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_1]][@y] = %[[VAL_0]] : <@Component6A>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_1]][@z] = %[[VAL_0]] : <@Component6A>, !felt.type +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Component6A> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Component6A>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@x] : <@Component6A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@y] : <@Component6A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_5]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@z] : <@Component6A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_6]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V4:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = struct.readm %[[V4]][@fa] : <@Main>, !struct.type<@Component6A> -// CHECK-NEXT: function.call @Component6A::@constrain(%[[V5]]) : (!struct.type<@Component6A>) -> () -// CHECK-NEXT: %[[V6:[0-9a-zA-Z_\.]+]] = struct.readm %[[V5]][@x] : <@Component6A>, !felt.type -// CHECK-NEXT: %[[V7:[0-9a-zA-Z_\.]+]] = felt.mul %[[V6]], %[[V6]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V8:[0-9a-zA-Z_\.]+]] = struct.readm %[[V4]][@"fc:!s<@Component6C>+fb:!s<@Component6B>+f"] : <@Main>, !felt.type -// CHECK-NEXT: constrain.eq %[[V7]], %[[V8]] : !felt.type, !felt.type -// CHECK-NEXT: function.return +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: struct.member @fa : !struct.type<@Component6A> +// CHECK-NEXT: struct.member @"fc:!s<@Component6C>+fb:!s<@Component6B>+f" : !felt.type +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = function.call @Component6A::@compute() : () -> !struct.type<@Component6A> +// CHECK-NEXT: struct.writem %[[VAL_7]][@fa] = %[[VAL_8]] : <@Main>, !struct.type<@Component6A> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_8]][@x] : <@Component6A>, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_9]], %[[VAL_9]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_7]][@"fc:!s<@Component6C>+fb:!s<@Component6B>+f"] = %[[VAL_10]] : <@Main>, !felt.type +// CHECK-NEXT: function.return %[[VAL_7]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_11:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@fa] : <@Main>, !struct.type<@Component6A> +// CHECK-NEXT: function.call @Component6A::@constrain(%[[VAL_12]]) : (!struct.type<@Component6A>) -> () +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_12]][@x] : <@Component6A>, !felt.type +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_13]], %[[VAL_13]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fc:!s<@Component6C>+fb:!s<@Component6B>+f"] : <@Main>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_14]], %[[VAL_15]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // ----- @@ -671,50 +665,51 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component7C { -// CHECK-NEXT: struct.member @"fb:!s<@Component7B>+f" : !felt.type -// CHECK-NEXT: function.def @compute(%[[V0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[V1:[0-9a-zA-Z_\.]+]]: !felt.type, %[[V2:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component7C> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = struct.new : <@Component7C> -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = felt.mul %[[V2]], %[[V2]] : !felt.type, !felt.type -// CHECK-NEXT: struct.writem %[[V3]][@"fb:!s<@Component7B>+f"] = %[[V4]] : <@Component7C>, !felt.type -// CHECK-NEXT: function.return %[[V3]] : !struct.type<@Component7C> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V5:[0-9a-zA-Z_\.]+]]: !struct.type<@Component7C>, %[[V6:[0-9a-zA-Z_\.]+]]: !felt.type, %[[V7:[0-9a-zA-Z_\.]+]]: !felt.type, %[[V8:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = felt.mul %[[V8]], %[[V8]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V10:[0-9a-zA-Z_\.]+]] = struct.readm %[[V5]][@"fb:!s<@Component7B>+f"] : <@Component7C>, !felt.type -// CHECK-NEXT: constrain.eq %[[V9]], %[[V10]] : !felt.type, !felt.type -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+x" : !felt.type -// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+y" : !felt.type -// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+z" : !felt.type -// CHECK-NEXT: struct.member @fc : !struct.type<@Component7C> -// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component7A>+x"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component7A>+y"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component7A>+z"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component7A>+y"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component7A>+z"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = function.call @Component7C::@compute(%[[V2]], %[[V3]], %[[V4]]) : (!felt.type, !felt.type, !felt.type) -> !struct.type<@Component7C> -// CHECK-NEXT: struct.writem %[[V0]][@fc] = %[[V5]] : <@Main>, !struct.type<@Component7C> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@Main> +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component7C { +// CHECK-NEXT: struct.member @"fb:!s<@Component7B>+f" : !felt.type +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_2:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Component7C> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@Component7C> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_2]], %[[VAL_2]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_3]][@"fb:!s<@Component7B>+f"] = %[[VAL_4]] : <@Component7C>, !felt.type +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@Component7C> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_5:[0-9a-zA-Z_\.]+]]: !struct.type<@Component7C>, %[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_7:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_8:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_5]][@"fb:!s<@Component7B>+f"] : <@Component7C>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V6:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V7:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V8:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: constrain.eq %[[V8]], %[[V7]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@fc] : <@Main>, !struct.type<@Component7C> -// CHECK-NEXT: %[[V10:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V11:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component7A>+y"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V12:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component7A>+z"] : <@Main>, !felt.type -// CHECK-NEXT: function.call @Component7C::@constrain(%[[V9]], %[[V10]], %[[V11]], %[[V12]]) : (!struct.type<@Component7C>, !felt.type, !felt.type, !felt.type) -> () -// CHECK-NEXT: function.return +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+x" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+y" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @"fa:!s<@Component7A>+z" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @fc : !struct.type<@Component7C> +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component7A>+x"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component7A>+y"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component7A>+z"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component7A>+y"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component7A>+z"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = function.call @Component7C::@compute(%[[VAL_13]], %[[VAL_14]], %[[VAL_15]]) : (!felt.type, !felt.type, !felt.type) -> !struct.type<@Component7C> +// CHECK-NEXT: struct.writem %[[VAL_11]][@fc] = %[[VAL_16]] : <@Main>, !struct.type<@Component7C> +// CHECK-NEXT: function.return %[[VAL_11]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_17:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_19]], %[[VAL_18]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@fc] : <@Main>, !struct.type<@Component7C> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component7A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component7A>+y"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component7A>+z"] : <@Main>, !felt.type +// CHECK-NEXT: function.call @Component7C::@constrain(%[[VAL_20]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]]) : (!struct.type<@Component7C>, !felt.type, !felt.type, !felt.type) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // ----- @@ -811,69 +806,69 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component8A { -// CHECK-NEXT: struct.member @x : !felt.type -// CHECK-NEXT: struct.member @y : !felt.type -// CHECK-NEXT: struct.member @z : !felt.type -// CHECK-NEXT: function.def @compute() -> !struct.type<@Component8A> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component8A> -// CHECK-NEXT: struct.writem %[[V1]][@x] = %[[V0]] : <@Component8A>, !felt.type -// CHECK-NEXT: struct.writem %[[V1]][@y] = %[[V0]] : <@Component8A>, !felt.type -// CHECK-NEXT: struct.writem %[[V1]][@z] = %[[V0]] : <@Component8A>, !felt.type -// CHECK-NEXT: function.return %[[V1]] : !struct.type<@Component8A> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V2:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@x] : <@Component8A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V4]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@y] : <@Component8A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V5]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V6:[0-9a-zA-Z_\.]+]] = struct.readm %[[V2]][@z] : <@Component8A>, !felt.type -// CHECK-NEXT: constrain.eq %[[V6]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Component8B { -// CHECK-NEXT: struct.member @f : !felt.type -// CHECK-NEXT: function.def @compute(%[[V0:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) -> !struct.type<@Component8B> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component8B> -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@z] : <@Component8A>, !felt.type -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = felt.mul %[[V2]], %[[V2]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = felt.mul %[[V3]], %[[V3]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = felt.mul %[[V4]], %[[V4]] : !felt.type, !felt.type -// CHECK-NEXT: struct.writem %[[V1]][@f] = %[[V5]] : <@Component8B>, !felt.type -// CHECK-NEXT: function.return %[[V1]] : !struct.type<@Component8B> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V6:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8B>, %[[V7:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V8:[0-9a-zA-Z_\.]+]] = struct.readm %[[V7]][@z] : <@Component8A>, !felt.type -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = felt.mul %[[V8]], %[[V8]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V10:[0-9a-zA-Z_\.]+]] = felt.mul %[[V9]], %[[V9]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V11:[0-9a-zA-Z_\.]+]] = felt.mul %[[V10]], %[[V10]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V12:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@f] : <@Component8B>, !felt.type -// CHECK-NEXT: constrain.eq %[[V11]], %[[V12]] : !felt.type, !felt.type -// CHECK-NEXT: function.return +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component8A { +// CHECK-NEXT: struct.member @x : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @y : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @z : !felt.type {llzk.pub} +// CHECK-NEXT: function.def @compute() -> !struct.type<@Component8A> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = struct.new : <@Component8A> +// CHECK-NEXT: struct.writem %[[VAL_1]][@x] = %[[VAL_0]] : <@Component8A>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_1]][@y] = %[[VAL_0]] : <@Component8A>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_1]][@z] = %[[VAL_0]] : <@Component8A>, !felt.type +// CHECK-NEXT: function.return %[[VAL_1]] : !struct.type<@Component8A> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@x] : <@Component8A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_4]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@y] : <@Component8A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_5]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_2]][@z] : <@Component8A>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_6]], %[[VAL_3]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: struct.member @fa : !struct.type<@Component8A> -// CHECK-NEXT: struct.member @"fc:!s<@Component8C>+fb" : !struct.type<@Component8B> -// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = function.call @Component8A::@compute() : () -> !struct.type<@Component8A> -// CHECK-NEXT: struct.writem %[[V0]][@fa] = %[[V1]] : <@Main>, !struct.type<@Component8A> -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = function.call @Component8B::@compute(%[[V1]]) : (!struct.type<@Component8A>) -> !struct.type<@Component8B> -// CHECK-NEXT: struct.writem %[[V0]][@"fc:!s<@Component8C>+fb"] = %[[V2]] : <@Main>, !struct.type<@Component8B> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@Main> +// CHECK-NEXT: struct.def @Component8B { +// CHECK-NEXT: struct.member @f : !felt.type +// CHECK-NEXT: function.def @compute(%[[VAL_7:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) -> !struct.type<@Component8B> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = struct.new : <@Component8B> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_7]][@z] : <@Component8A>, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_9]], %[[VAL_9]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_10]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_11]], %[[VAL_11]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_8]][@f] = %[[VAL_12]] : <@Component8B>, !felt.type +// CHECK-NEXT: function.return %[[VAL_8]] : !struct.type<@Component8B> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_13:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8B>, %[[VAL_14:[0-9a-zA-Z_\.]+]]: !struct.type<@Component8A>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_14]][@z] : <@Component8A>, !felt.type +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_15]], %[[VAL_15]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_16]], %[[VAL_16]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_17]], %[[VAL_17]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_13]][@f] : <@Component8B>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_18]], %[[VAL_19]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V3:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = struct.readm %[[V3]][@fa] : <@Main>, !struct.type<@Component8A> -// CHECK-NEXT: function.call @Component8A::@constrain(%[[V4]]) : (!struct.type<@Component8A>) -> () -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = struct.readm %[[V3]][@"fc:!s<@Component8C>+fb"] : <@Main>, !struct.type<@Component8B> -// CHECK-NEXT: function.call @Component8B::@constrain(%[[V5]], %[[V4]]) : (!struct.type<@Component8B>, !struct.type<@Component8A>) -> () -// CHECK-NEXT: function.return +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: struct.member @fa : !struct.type<@Component8A> +// CHECK-NEXT: struct.member @"fc:!s<@Component8C>+fb" : !struct.type<@Component8B> +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = function.call @Component8A::@compute() : () -> !struct.type<@Component8A> +// CHECK-NEXT: struct.writem %[[VAL_20]][@fa] = %[[VAL_21]] : <@Main>, !struct.type<@Component8A> +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = function.call @Component8B::@compute(%[[VAL_21]]) : (!struct.type<@Component8A>) -> !struct.type<@Component8B> +// CHECK-NEXT: struct.writem %[[VAL_20]][@"fc:!s<@Component8C>+fb"] = %[[VAL_22]] : <@Main>, !struct.type<@Component8B> +// CHECK-NEXT: function.return %[[VAL_20]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_23:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_23]][@fa] : <@Main>, !struct.type<@Component8A> +// CHECK-NEXT: function.call @Component8A::@constrain(%[[VAL_24]]) : (!struct.type<@Component8A>) -> () +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_23]][@"fc:!s<@Component8C>+fb"] : <@Main>, !struct.type<@Component8B> +// CHECK-NEXT: function.call @Component8B::@constrain(%[[VAL_25]], %[[VAL_24]]) : (!struct.type<@Component8B>, !struct.type<@Component8A>) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // ----- @@ -961,49 +956,50 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Component9C { -// CHECK-NEXT: struct.member @"fb:!s<@Component9B>+f" : !felt.type -// CHECK-NEXT: function.def @compute(%[[V0:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[V1:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[V2:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}) -> !struct.type<@Component9C> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = struct.new : <@Component9C> -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = felt.mul %[[V2]], %[[V2]] : !felt.type, !felt.type -// CHECK-NEXT: struct.writem %[[V3]][@"fb:!s<@Component9B>+f"] = %[[V4]] : <@Component9C>, !felt.type -// CHECK-NEXT: function.return %[[V3]] : !struct.type<@Component9C> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V5:[0-9a-zA-Z_\.]+]]: !struct.type<@Component9C>, %[[V6:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[V7:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[V8:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = felt.mul %[[V8]], %[[V8]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V10:[0-9a-zA-Z_\.]+]] = struct.readm %[[V5]][@"fb:!s<@Component9B>+f"] : <@Component9C>, !felt.type -// CHECK-NEXT: constrain.eq %[[V9]], %[[V10]] : !felt.type, !felt.type -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+x" : !felt.type -// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+y" : !felt.type -// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+z" : !felt.type -// CHECK-NEXT: struct.member @fc : !struct.type<@Component9C> -// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component9A>+x"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component9A>+y"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: struct.writem %[[V0]][@"fa:!s<@Component9A>+z"] = %[[V1]] : <@Main>, !felt.type -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component9A>+y"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = struct.readm %[[V0]][@"fa:!s<@Component9A>+z"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = function.call @Component9C::@compute(%[[V2]], %[[V3]], %[[V4]]) : (!felt.type, !felt.type, !felt.type) -> !struct.type<@Component9C> -// CHECK-NEXT: struct.writem %[[V0]][@fc] = %[[V5]] : <@Main>, !struct.type<@Component9C> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@Main> +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Component9C { +// CHECK-NEXT: struct.member @"fb:!s<@Component9B>+f" : !felt.type +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[VAL_2:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}) -> !struct.type<@Component9C> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = struct.new : <@Component9C> +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_2]], %[[VAL_2]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_3]][@"fb:!s<@Component9B>+f"] = %[[VAL_4]] : <@Component9C>, !felt.type +// CHECK-NEXT: function.return %[[VAL_3]] : !struct.type<@Component9C> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_5:[0-9a-zA-Z_\.]+]]: !struct.type<@Component9C>, %[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[VAL_7:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}, %[[VAL_8:[0-9a-zA-Z_\.]+]]: !felt.type {llzk.pub}) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_5]][@"fb:!s<@Component9B>+f"] : <@Component9C>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V6:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V7:[0-9a-zA-Z_\.]+]] = felt.const 12 -// CHECK-NEXT: %[[V8:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: constrain.eq %[[V8]], %[[V7]] : !felt.type, !felt.type -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@fc] : <@Main>, !struct.type<@Component9C> -// CHECK-NEXT: %[[V10:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V11:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component9A>+y"] : <@Main>, !felt.type -// CHECK-NEXT: %[[V12:[0-9a-zA-Z_\.]+]] = struct.readm %[[V6]][@"fa:!s<@Component9A>+z"] : <@Main>, !felt.type -// CHECK-NEXT: function.call @Component9C::@constrain(%[[V9]], %[[V10]], %[[V11]], %[[V12]]) : (!struct.type<@Component9C>, !felt.type, !felt.type, !felt.type) -> () -// CHECK-NEXT: function.return +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+x" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+y" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @"fa:!s<@Component9A>+z" : !felt.type {llzk.pub} +// CHECK-NEXT: struct.member @fc : !struct.type<@Component9C> +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: %[[VAL_12:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component9A>+x"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component9A>+y"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_11]][@"fa:!s<@Component9A>+z"] = %[[VAL_12]] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component9A>+y"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_11]][@"fa:!s<@Component9A>+z"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = function.call @Component9C::@compute(%[[VAL_13]], %[[VAL_14]], %[[VAL_15]]) : (!felt.type, !felt.type, !felt.type) -> !struct.type<@Component9C> +// CHECK-NEXT: struct.writem %[[VAL_11]][@fc] = %[[VAL_16]] : <@Main>, !struct.type<@Component9C> +// CHECK-NEXT: function.return %[[VAL_11]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_17:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = felt.const 12 +// CHECK-NEXT: %[[VAL_19:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_19]], %[[VAL_18]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@fc] : <@Main>, !struct.type<@Component9C> +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component9A>+x"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component9A>+y"] : <@Main>, !felt.type +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@"fa:!s<@Component9A>+z"] : <@Main>, !felt.type +// CHECK-NEXT: function.call @Component9C::@constrain(%[[VAL_20]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]]) : (!struct.type<@Component9C>, !felt.type, !felt.type, !felt.type) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Transforms/InlineStructs/inline_structs_pass_2.llzk b/test/Transforms/InlineStructs/inline_structs_pass_2.llzk index f437eacadb..315a94e521 100644 --- a/test/Transforms/InlineStructs/inline_structs_pass_2.llzk +++ b/test/Transforms/InlineStructs/inline_structs_pass_2.llzk @@ -1,5 +1,4 @@ -// RUN: llzk-opt -split-input-file --pass-pipeline='builtin.module(llzk-flatten{cleanup=main-as-root},llzk-inline-structs)' -verify-diagnostics %s | FileCheck %s -// COM: This could be merged with `inline_structs_pass.llzk` once the issue with the default cleanup method of `llzk-flatten` is fixed (LLZK-303). +// RUN: llzk-opt -split-input-file --pass-pipeline='builtin.module(llzk-flatten{cleanup=main-as-root},llzk-inline-structs,llzk-unused-declaration-elim{remove-structs})' -verify-diagnostics %s | FileCheck %s // TESTS: Inlining after `llzk-flatten` pass removes struct parameters module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { @@ -54,13 +53,15 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@Main> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V1:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: function.return +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: } // ----- @@ -137,45 +138,48 @@ module attributes {llzk.main = !struct.type<@TMain::@Main>, llzk.lang} { } } } -// CHECK-LABEL: struct.def @TPair_f_77_Pair { -// CHECK-NEXT: function.def @compute() -> !struct.type<@TPair_f_77_Pair> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@TPair_f_77_Pair> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V1:[0-9a-zA-Z_\.]+]]: !struct.type<@TPair_f_77_Pair>) attributes {function.allow_constraint} { -// CHECK-NEXT: function.return -// CHECK-NEXT: } -// CHECK-NEXT: } -// -// CHECK-LABEL: struct.def @Main { -// CHECK-NEXT: struct.member @"sub:!s<@TMakeGuess_77_MakeGuess>+dat" : !array.type<4 x !struct.type<@TPair_f_77_Pair>> -// CHECK-NEXT: function.def @compute() -> !struct.type<@TMain::@Main> attributes {function.allow_witness} { -// CHECK-NEXT: %[[V0:[0-9a-zA-Z_\.]+]] = struct.new : <@TMain::@Main> -// CHECK-NEXT: %[[V1:[0-9a-zA-Z_\.]+]] = arith.constant 3 : index -// CHECK-NEXT: %[[V2:[0-9a-zA-Z_\.]+]] = arith.constant 2 : index -// CHECK-NEXT: %[[V3:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: %[[V4:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[V5:[0-9a-zA-Z_\.]+]] = array.new : <4 x !struct.type<@TPair_f_77_Pair>> -// CHECK-NEXT: %[[V6:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: array.write %[[V5]]{{\[}}%[[V4]]] = %[[V6]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: %[[V7:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: array.write %[[V5]]{{\[}}%[[V3]]] = %[[V7]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: %[[V8:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: array.write %[[V5]]{{\[}}%[[V2]]] = %[[V8]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: %[[V9:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: array.write %[[V5]]{{\[}}%[[V1]]] = %[[V9]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: struct.writem %[[V0]][@"sub:!s<@TMakeGuess_77_MakeGuess>+dat"] = %[[V5]] : <@TMain::@Main>, !array.type<4 x !struct.type<@TPair_f_77_Pair>> -// CHECK-NEXT: function.return %[[V0]] : !struct.type<@TMain::@Main> +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@TMain::@Main>} { +// CHECK-NEXT: struct.def @TPair_f_77_Pair { +// CHECK-NEXT: function.def @compute() -> !struct.type<@TPair_f_77_Pair> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@TPair_f_77_Pair> +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@TPair_f_77_Pair>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.def @constrain(%[[V10:[0-9a-zA-Z_\.]+]]: !struct.type<@TMain::@Main>) attributes {function.allow_constraint} { -// CHECK-NEXT: %[[V11:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index -// CHECK-NEXT: %[[V12:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index -// CHECK-NEXT: %[[V13:[0-9a-zA-Z_\.]+]] = struct.readm %[[V10]][@"sub:!s<@TMakeGuess_77_MakeGuess>+dat"] : <@TMain::@Main>, !array.type<4 x !struct.type<@TPair_f_77_Pair>> -// CHECK-NEXT: %[[V14:[0-9a-zA-Z_\.]+]] = array.len %[[V13]], %[[V11]] : <4 x !struct.type<@TPair_f_77_Pair>> -// CHECK-NEXT: scf.for %[[V15:[0-9a-zA-Z_\.]+]] = %[[V11]] to %[[V14]] step %[[V12]] { -// CHECK-NEXT: %[[V16:[0-9a-zA-Z_\.]+]] = array.read %[[V13]]{{\[}}%[[V15]]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> -// CHECK-NEXT: function.call @TPair_f_77_Pair::@constrain(%[[V16]]) : (!struct.type<@TPair_f_77_Pair>) -> () +// CHECK-NEXT: module @TMain { +// CHECK-NEXT: struct.def @Main { +// CHECK-NEXT: struct.member @"sub:!s<@TMakeGuess_77_MakeGuess>+dat" : !array.type<4 x !struct.type<@TPair_f_77_Pair>> +// CHECK-NEXT: function.def @compute() -> !struct.type<@TMain::@Main> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@TMain::@Main> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant 3 : index +// CHECK-NEXT: %[[VAL_4:[0-9a-zA-Z_\.]+]] = arith.constant 2 : index +// CHECK-NEXT: %[[VAL_5:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = array.new : <4 x !struct.type<@TPair_f_77_Pair>> +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: array.write %[[VAL_7]]{{\[}}%[[VAL_6]]] = %[[VAL_8]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: array.write %[[VAL_7]]{{\[}}%[[VAL_5]]] = %[[VAL_9]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: array.write %[[VAL_7]]{{\[}}%[[VAL_4]]] = %[[VAL_10]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: %[[VAL_11:[0-9a-zA-Z_\.]+]] = function.call @TPair_f_77_Pair::@compute() : () -> !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: array.write %[[VAL_7]]{{\[}}%[[VAL_3]]] = %[[VAL_11]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: struct.writem %[[VAL_2]][@"sub:!s<@TMakeGuess_77_MakeGuess>+dat"] = %[[VAL_7]] : <@TMain::@Main>, !array.type<4 x !struct.type<@TPair_f_77_Pair>> +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@TMain::@Main> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_12:[0-9a-zA-Z_\.]+]]: !struct.type<@TMain::@Main>) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_12]][@"sub:!s<@TMakeGuess_77_MakeGuess>+dat"] : <@TMain::@Main>, !array.type<4 x !struct.type<@TPair_f_77_Pair>> +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = array.len %[[VAL_15]], %[[VAL_13]] : <4 x !struct.type<@TPair_f_77_Pair>> +// CHECK-NEXT: scf.for %[[VAL_17:[0-9a-zA-Z_\.]+]] = %[[VAL_13]] to %[[VAL_16]] step %[[VAL_14]] { +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = array.read %[[VAL_15]]{{\[}}%[[VAL_17]]] : <4 x !struct.type<@TPair_f_77_Pair>>, !struct.type<@TPair_f_77_Pair> +// CHECK-NEXT: function.call @TPair_f_77_Pair::@constrain(%[[VAL_18]]) : (!struct.type<@TPair_f_77_Pair>) -> () +// CHECK-NEXT: } +// CHECK-NEXT: function.return +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: function.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/test/Transforms/PolyLowering/poly_lowering_fail_low_deg.llzk b/test/Transforms/PolyLowering/poly_lowering_fail_low_deg.llzk index 644f6727c0..e76379c46d 100644 --- a/test/Transforms/PolyLowering/poly_lowering_fail_low_deg.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_fail_low_deg.llzk @@ -1,4 +1,4 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering="max-degree=1" -verify-diagnostics %s +// RUN: llzk-opt -I %S -split-input-file -llzk-poly-lowering-pass="max-degree=1" -verify-diagnostics %s // expected-error@+1 {{Invalid max degree: 1. Must be >= 2.}} module attributes {llzk.lang} { diff --git a/test/Transforms/PolyLowering/poly_lowering_fail_reserved_name.llzk b/test/Transforms/PolyLowering/poly_lowering_fail_reserved_name.llzk index b88ad600fb..92a28598ed 100644 --- a/test/Transforms/PolyLowering/poly_lowering_fail_reserved_name.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_fail_reserved_name.llzk @@ -1,4 +1,4 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering="max-degree=2" -verify-diagnostics %s +// RUN: llzk-opt -I %S -split-input-file -llzk-poly-lowering-pass="max-degree=2" -verify-diagnostics %s module attributes {llzk.lang} { struct.def @CmpConstraint { diff --git a/test/Transforms/PolyLowering/poly_lowering_pass_deg2.llzk b/test/Transforms/PolyLowering/poly_lowering_pass_deg2.llzk index 0e3a18a727..a691193de3 100644 --- a/test/Transforms/PolyLowering/poly_lowering_pass_deg2.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_pass_deg2.llzk @@ -1,6 +1,6 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering -verify-diagnostics %s | FileCheck --enable-var-scope %s +// RUN: llzk-opt -split-input-file -llzk-full-poly-lowering='flatten-inline={inlining={max-merge-complexity=1}}' -verify-diagnostics %s | FileCheck --enable-var-scope %s -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { // lowers constraints to be at most degree 2 polynomials struct.def @CmpConstraint { function.def @compute(%a: !felt.type, %b: !felt.type) -> !struct.type<@CmpConstraint> { @@ -19,29 +19,30 @@ module attributes {llzk.lang} { } } } - -// CHECK-LABEL: struct.def @CmpConstraint { -// CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@CmpConstraint> attributes {function.allow_witness} { -// CHECK: %[[VAL_2:.*]] = struct.new : <@CmpConstraint> -// CHECK: %[[VAL_3:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@CmpConstraint>, !felt.type -// CHECK: function.return %[[VAL_2]] : !struct.type<@CmpConstraint> -// CHECK: } -// CHECK: function.def @constrain(%[[VAL_4:.*]]: !struct.type<@CmpConstraint>, %[[VAL_5:.*]]: !felt.type, %[[VAL_6:.*]]: !felt.type) attributes {function.allow_constraint} { -// CHECK: %[[VAL_7:.*]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type -// CHECK: %[[VAL_8:.*]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@CmpConstraint>, !felt.type -// CHECK: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type -// CHECK: %[[VAL_9:.*]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: %[[VAL_10:.*]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type -// CHECK: function.return -// CHECK: } -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type -// CHECK: } +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { +// CHECK-NEXT: struct.def @CmpConstraint { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@CmpConstraint> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@CmpConstraint> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@CmpConstraint>, !felt.type +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@CmpConstraint> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@CmpConstraint>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@CmpConstraint>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: } // ----- -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@Mod1>} { struct.def @Mod2 { function.def @compute(%a: !felt.type, %b: !felt.type) -> !struct.type<@Mod2> { %self = struct.new : !struct.type<@Mod2> @@ -77,60 +78,60 @@ module attributes {llzk.lang} { } } } - -// CHECK-LABEL: struct.def @Mod2 { -// CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@Mod2> attributes {function.allow_witness} { -// CHECK: %[[VAL_2:.*]] = struct.new : <@Mod2> -// CHECK: %[[VAL_3:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@Mod2>, !felt.type -// CHECK: function.return %[[VAL_2]] : !struct.type<@Mod2> -// CHECK: } -// CHECK: function.def @constrain(%[[VAL_4:.*]]: !struct.type<@Mod2>, %[[VAL_5:.*]]: !felt.type, %[[VAL_6:.*]]: !felt.type) attributes {function.allow_constraint} { -// CHECK: %[[VAL_7:.*]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type -// CHECK: %[[VAL_8:.*]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@Mod2>, !felt.type -// CHECK: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type -// CHECK: %[[VAL_9:.*]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: %[[VAL_10:.*]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type -// CHECK: function.return -// CHECK: } -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type -// CHECK: } - -// CHECK-LABEL: struct.def @Mod1 { -// CHECK: struct.member @mod2 : !struct.type<@Mod2> -// CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@Mod1> attributes {function.allow_witness} { -// CHECK: %[[VAL_2:.*]] = struct.new : <@Mod1> -// CHECK: %[[VAL_3:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_3]] : <@Mod1>, !felt.type -// CHECK: %[[VAL_5:.*]] = felt.mul %[[VAL_3]], %[[VAL_3]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_2] = %[[VAL_5]] : <@Mod1>, !felt.type -// CHECK: %[[VAL_6:.*]] = felt.mul %[[VAL_3]], %[[VAL_0]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_3] = %[[VAL_6]] : <@Mod1>, !felt.type -// CHECK: function.return %[[VAL_2]] : !struct.type<@Mod1> -// CHECK: } -// CHECK: function.def @constrain(%[[VAL_7:.*]]: !struct.type<@Mod1>, %[[VAL_8:.*]]: !felt.type, %[[VAL_9:.*]]: !felt.type) attributes {function.allow_constraint} { -// CHECK: %[[VAL_10:.*]] = felt.mul %[[VAL_8]], %[[VAL_9]] : !felt.type, !felt.type -// CHECK: %[[VAL_11:.*]] = struct.readm %[[VAL_7]][@__llzk_poly_lowering_pass_aux_member_1] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_11]], %[[VAL_10]] : !felt.type, !felt.type -// CHECK: %[[VAL_12:.*]] = felt.mul %[[VAL_11]], %[[VAL_11]] : !felt.type, !felt.type -// CHECK: %[[VAL_13:.*]] = felt.mul %[[VAL_11]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: %[[VAL_14:.*]] = struct.readm %[[VAL_7]][@mod2] : <@Mod1>, !struct.type<@Mod2> -// CHECK: %[[VAL_15:.*]] = struct.readm %[[VAL_7]][@__llzk_poly_lowering_pass_aux_member_2] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_15]], %[[VAL_12]] : !felt.type, !felt.type -// CHECK: %[[VAL_16:.*]] = struct.readm %[[VAL_7]][@__llzk_poly_lowering_pass_aux_member_3] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_16]], %[[VAL_13]] : !felt.type, !felt.type -// CHECK: function.call @Mod2::@constrain(%[[VAL_14]], %[[VAL_15]], %[[VAL_16]]) : (!struct.type<@Mod2>, !felt.type, !felt.type) -> () -// CHECK: function.return -// CHECK: } -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_1 : !felt.type -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_2 : !felt.type -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_3 : !felt.type -// CHECK: } +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Mod1>} { +// CHECK-NEXT: struct.def @Mod2 { +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Mod2> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Mod2> +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@Mod2>, !felt.type +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@Mod2> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@Mod2>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@Mod2>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @Mod1 { +// CHECK-NEXT: struct.member @mod2 : !struct.type<@Mod2> +// CHECK-NEXT: function.def @compute(%[[VAL_11:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_12:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Mod1> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = struct.new : <@Mod1> +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_11]], %[[VAL_12]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_14]] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_14]], %[[VAL_14]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_2] = %[[VAL_15]] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_14]], %[[VAL_11]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_3] = %[[VAL_16]] : <@Mod1>, !felt.type +// CHECK-NEXT: function.return %[[VAL_13]] : !struct.type<@Mod1> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_17:[0-9a-zA-Z_\.]+]]: !struct.type<@Mod1>, %[[VAL_18:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_19:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_20:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_18]], %[[VAL_19]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_21:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@__llzk_poly_lowering_pass_aux_member_1] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_21]], %[[VAL_20]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_21]], %[[VAL_21]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_21]], %[[VAL_18]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@mod2] : <@Mod1>, !struct.type<@Mod2> +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@__llzk_poly_lowering_pass_aux_member_2] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_25]], %[[VAL_22]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_17]][@__llzk_poly_lowering_pass_aux_member_3] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_26]], %[[VAL_23]] : !felt.type, !felt.type +// CHECK-NEXT: function.call @Mod2::@constrain(%[[VAL_24]], %[[VAL_25]], %[[VAL_26]]) : (!struct.type<@Mod2>, !felt.type, !felt.type) -> () +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_1 : !felt.type +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_2 : !felt.type +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_3 : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: } // ----- -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@Mod1>} { struct.def @Mod2 { struct.member @val: !felt.type {llzk.pub} function.def @compute(%a: !felt.type, %b: !felt.type) -> !struct.type<@Mod2> { @@ -175,68 +176,68 @@ module attributes {llzk.lang} { } } } - -// CHECK-LABEL: struct.def @Mod2 { -// CHECK: struct.member @val : !felt.type -// CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@Mod2> attributes {function.allow_witness} { -// CHECK: %[[VAL_2:.*]] = struct.new : <@Mod2> -// CHECK: struct.writem %[[VAL_2]][@val] = %[[VAL_0]] : <@Mod2>, !felt.type -// CHECK: %[[VAL_3:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@Mod2>, !felt.type -// CHECK: function.return %[[VAL_2]] : !struct.type<@Mod2> -// CHECK: } -// CHECK: function.def @constrain(%[[VAL_4:.*]]: !struct.type<@Mod2>, %[[VAL_5:.*]]: !felt.type, %[[VAL_6:.*]]: !felt.type) attributes {function.allow_constraint} { -// CHECK: %[[VAL_7:.*]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type -// CHECK: %[[VAL_8:.*]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@Mod2>, !felt.type -// CHECK: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type -// CHECK: %[[VAL_9:.*]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: %[[VAL_10:.*]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type -// CHECK: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type -// CHECK: function.return -// CHECK: } -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type -// CHECK: } - -// CHECK-LABEL: struct.def @Mod1 { -// CHECK: struct.member @mod2 : !struct.type<@Mod2> -// CHECK: struct.member @val : !felt.type -// CHECK: function.def @compute(%[[VAL_0:.*]]: !felt.type, %[[VAL_1:.*]]: !felt.type) -> !struct.type<@Mod1> attributes {function.allow_witness} { -// CHECK: %[[VAL_2:.*]] = struct.new : <@Mod1> -// CHECK: %[[VAL_3:.*]] = struct.readm %[[VAL_2]][@val] : <@Mod1>, !felt.type -// CHECK: %[[VAL_4:.*]] = felt.mul %[[VAL_3]], %[[VAL_3]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_4]] : <@Mod1>, !felt.type -// CHECK: %[[VAL_5:.*]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_2] = %[[VAL_5]] : <@Mod1>, !felt.type -// CHECK: %[[VAL_7:.*]] = felt.mul %[[VAL_5]], %[[VAL_5]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_3] = %[[VAL_7]] : <@Mod1>, !felt.type -// CHECK: %[[VAL_8:.*]] = felt.mul %[[VAL_5]], %[[VAL_0]] : !felt.type, !felt.type -// CHECK: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_4] = %[[VAL_8]] : <@Mod1>, !felt.type -// CHECK: function.return %[[VAL_2]] : !struct.type<@Mod1> -// CHECK: } -// CHECK: function.def @constrain(%[[VAL_9:.*]]: !struct.type<@Mod1>, %[[VAL_10:.*]]: !felt.type, %[[VAL_11:.*]]: !felt.type) attributes {function.allow_constraint} { -// CHECK: %[[VAL_12:.*]] = felt.mul %[[VAL_10]], %[[VAL_11]] : !felt.type, !felt.type -// CHECK: %[[VAL_13:.*]] = struct.readm %[[VAL_9]][@__llzk_poly_lowering_pass_aux_member_2] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_13]], %[[VAL_12]] : !felt.type, !felt.type -// CHECK: %[[VAL_14:.*]] = felt.mul %[[VAL_13]], %[[VAL_13]] : !felt.type, !felt.type -// CHECK: %[[VAL_15:.*]] = felt.mul %[[VAL_13]], %[[VAL_10]] : !felt.type, !felt.type -// CHECK: %[[VAL_16:.*]] = struct.readm %[[VAL_9]][@mod2] : <@Mod1>, !struct.type<@Mod2> -// CHECK: %[[VAL_17:.*]] = struct.readm %[[VAL_9]][@__llzk_poly_lowering_pass_aux_member_3] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_17]], %[[VAL_14]] : !felt.type, !felt.type -// CHECK: %[[VAL_18:.*]] = struct.readm %[[VAL_9]][@__llzk_poly_lowering_pass_aux_member_4] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_18]], %[[VAL_15]] : !felt.type, !felt.type -// CHECK: function.call @Mod2::@constrain(%[[VAL_16]], %[[VAL_17]], %[[VAL_18]]) : (!struct.type<@Mod2>, !felt.type, !felt.type) -> () -// CHECK: %[[VAL_19:.*]] = struct.readm %[[VAL_16]][@val] : <@Mod2>, !felt.type -// CHECK: %[[VAL_20:.*]] = felt.mul %[[VAL_19]], %[[VAL_19]] : !felt.type, !felt.type -// CHECK: %[[VAL_21:.*]] = struct.readm %[[VAL_9]][@__llzk_poly_lowering_pass_aux_member_1] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_21]], %[[VAL_20]] : !felt.type, !felt.type -// CHECK: %[[VAL_22:.*]] = felt.mul %[[VAL_21]], %[[VAL_21]] : !felt.type, !felt.type -// CHECK: %[[VAL_23:.*]] = struct.readm %[[VAL_9]][@val] : <@Mod1>, !felt.type -// CHECK: constrain.eq %[[VAL_23]], %[[VAL_22]] : !felt.type, !felt.type -// CHECK: function.return -// CHECK: } -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_1 : !felt.type -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_2 : !felt.type -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_3 : !felt.type -// CHECK: struct.member @__llzk_poly_lowering_pass_aux_member_4 : !felt.type -// CHECK: } +// CHECK-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Mod1>} { +// CHECK-NEXT: struct.def @Mod2 { +// CHECK-NEXT: struct.member @val : !felt.type {llzk.pub} +// CHECK-NEXT: function.def @compute(%[[VAL_0:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_1:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Mod2> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@Mod2> +// CHECK-NEXT: struct.writem %[[VAL_2]][@val] = %[[VAL_0]] : <@Mod2>, !felt.type +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_0]], %[[VAL_1]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_2]][@__llzk_poly_lowering_pass_aux_member_0] = %[[VAL_3]] : <@Mod2>, !felt.type +// CHECK-NEXT: function.return %[[VAL_2]] : !struct.type<@Mod2> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !struct.type<@Mod2>, %[[VAL_5:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_6:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_5]], %[[VAL_6]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_4]][@__llzk_poly_lowering_pass_aux_member_0] : <@Mod2>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_8]], %[[VAL_7]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_10:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_8]], %[[VAL_5]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_8]] : !felt.type, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_9]], %[[VAL_10]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_0 : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: struct.def @Mod1 { +// CHECK-NEXT: struct.member @mod2 : !struct.type<@Mod2> +// CHECK-NEXT: struct.member @val : !felt.type +// CHECK-NEXT: function.def @compute(%[[VAL_11:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_12:[0-9a-zA-Z_\.]+]]: !felt.type) -> !struct.type<@Mod1> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_13:[0-9a-zA-Z_\.]+]] = struct.new : <@Mod1> +// CHECK-NEXT: %[[VAL_14:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_13]][@val] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_15:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_14]], %[[VAL_14]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_1] = %[[VAL_15]] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_16:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_11]], %[[VAL_12]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_2] = %[[VAL_16]] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_17:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_16]], %[[VAL_16]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_3] = %[[VAL_17]] : <@Mod1>, !felt.type +// CHECK-NEXT: %[[VAL_18:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_16]], %[[VAL_11]] : !felt.type, !felt.type +// CHECK-NEXT: struct.writem %[[VAL_13]][@__llzk_poly_lowering_pass_aux_member_4] = %[[VAL_18]] : <@Mod1>, !felt.type +// CHECK-NEXT: function.return %[[VAL_13]] : !struct.type<@Mod1> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_19:[0-9a-zA-Z_\.]+]]: !struct.type<@Mod1>, %[[VAL_20:[0-9a-zA-Z_\.]+]]: !felt.type, %[[VAL_21:[0-9a-zA-Z_\.]+]]: !felt.type) attributes {function.allow_constraint} { +// CHECK-NEXT: %[[VAL_22:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_20]], %[[VAL_21]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_23:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@__llzk_poly_lowering_pass_aux_member_2] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_23]], %[[VAL_22]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_24:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_23]], %[[VAL_23]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_25:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_23]], %[[VAL_20]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_26:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@mod2] : <@Mod1>, !struct.type<@Mod2> +// CHECK-NEXT: %[[VAL_27:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@__llzk_poly_lowering_pass_aux_member_3] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_27]], %[[VAL_24]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_28:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@__llzk_poly_lowering_pass_aux_member_4] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_28]], %[[VAL_25]] : !felt.type, !felt.type +// CHECK-NEXT: function.call @Mod2::@constrain(%[[VAL_26]], %[[VAL_27]], %[[VAL_28]]) : (!struct.type<@Mod2>, !felt.type, !felt.type) -> () +// CHECK-NEXT: %[[VAL_29:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_26]][@val] : <@Mod2>, !felt.type +// CHECK-NEXT: %[[VAL_30:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_29]], %[[VAL_29]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_31:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@__llzk_poly_lowering_pass_aux_member_1] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_31]], %[[VAL_30]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_32:[0-9a-zA-Z_\.]+]] = felt.mul %[[VAL_31]], %[[VAL_31]] : !felt.type, !felt.type +// CHECK-NEXT: %[[VAL_33:[0-9a-zA-Z_\.]+]] = struct.readm %[[VAL_19]][@val] : <@Mod1>, !felt.type +// CHECK-NEXT: constrain.eq %[[VAL_33]], %[[VAL_32]] : !felt.type, !felt.type +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_1 : !felt.type +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_2 : !felt.type +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_3 : !felt.type +// CHECK-NEXT: struct.member @__llzk_poly_lowering_pass_aux_member_4 : !felt.type +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk b/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk index d43c2e1241..e4486b2a02 100644 --- a/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk +++ b/test/Transforms/PolyLowering/poly_lowering_pass_deg3.llzk @@ -1,4 +1,4 @@ -// RUN: llzk-opt -I %S -split-input-file -llzk-full-poly-lowering="max-degree=3" -verify-diagnostics %s | FileCheck --enable-var-scope %s +// RUN: llzk-opt -I %S -split-input-file -llzk-poly-lowering-pass="max-degree=3" -verify-diagnostics %s | FileCheck --enable-var-scope %s module attributes {llzk.lang} { // lowers constraints to be at most degree 3 polynomials diff --git a/test/Transforms/R1CSLowering/r1cs_lowering_pass.llzk b/test/Transforms/R1CSLowering/r1cs_lowering_pass.llzk index 4acecae3ff..061dd4dda3 100644 --- a/test/Transforms/R1CSLowering/r1cs_lowering_pass.llzk +++ b/test/Transforms/R1CSLowering/r1cs_lowering_pass.llzk @@ -1,6 +1,6 @@ // RUN: llzk-opt -split-input-file -llzk-full-r1cs-lowering -verify-diagnostics %s | FileCheck --enable-var-scope %s -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { // lowers constraints to be at most degree 2 polynomials struct.def @CmpConstraint { struct.member @val: !felt.type {llzk.pub} @@ -46,7 +46,7 @@ module attributes {llzk.lang} { // CHECK-DAG: } // ----- -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@CmpConstraint>} { // lowers constraints to be at most degree 2 polynomials struct.def @CmpConstraint { struct.member @val: !felt.type {llzk.pub} diff --git a/test/Transforms/R1CSLowering/r1cs_lowering_quadratic_linear_sign.llzk b/test/Transforms/R1CSLowering/r1cs_lowering_quadratic_linear_sign.llzk index 6e0de2ce43..8e982496d6 100644 --- a/test/Transforms/R1CSLowering/r1cs_lowering_quadratic_linear_sign.llzk +++ b/test/Transforms/R1CSLowering/r1cs_lowering_quadratic_linear_sign.llzk @@ -1,6 +1,6 @@ // RUN: llzk-opt -split-input-file -llzk-full-r1cs-lowering -verify-diagnostics %s | FileCheck --enable-var-scope %s -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@MulOut>} { struct.def @MulOut { struct.member @lhs_out : !felt.type<"babybear"> {llzk.pub} struct.member @rhs_out : !felt.type<"babybear"> {llzk.pub} diff --git a/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk b/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk index 3a0e0abaee..2def6d7ec9 100644 --- a/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk +++ b/test/Transforms/R1CSLowering/r1cs_lowering_typed_aux_member.llzk @@ -1,6 +1,6 @@ // RUN: llzk-opt -split-input-file -llzk-full-r1cs-lowering --verify-each %s | FileCheck --enable-var-scope %s -module attributes {llzk.lang} { +module attributes {llzk.lang, llzk.main = !struct.type<@TypedR1CSAux>} { struct.def @TypedR1CSAux { struct.member @out : !felt.type<"babybear"> {llzk.pub} diff --git a/test/Transforms/RedundantAndUnusedElim/unused_decl_after_redundant_elim.llzk b/test/Transforms/RedundantAndUnusedElim/unused_decl_after_redundant_elim.llzk new file mode 100644 index 0000000000..ba9463d4a9 --- /dev/null +++ b/test/Transforms/RedundantAndUnusedElim/unused_decl_after_redundant_elim.llzk @@ -0,0 +1,71 @@ +// RUN: llzk-opt -llzk-unused-declaration-elim -llzk-duplicate-read-write-elim -llzk-duplicate-op-elim %s 2>&1 | FileCheck %s --check-prefix BASE +// RUN: llzk-opt -llzk-unused-declaration-elim -llzk-duplicate-read-write-elim -llzk-duplicate-op-elim -llzk-unused-declaration-elim %s 2>&1 | FileCheck %s --check-prefix EXTRA + +// TESTS: running `-llzk-unused-declaration-elim` after the duplicate removal passes can elminate something that could not be eliminated before. +module attributes {llzk.lang} { + struct.def @A { + function.def @compute() -> !struct.type<@A> { + %self = struct.new : !struct.type<@A> + function.return %self : !struct.type<@A> + } + function.def @constrain(%self : !struct.type<@A>) { + function.return + } + } + + struct.def @B { + struct.member @a : !struct.type<@A> + function.def @compute() -> !struct.type<@B> { + %self = struct.new : !struct.type<@B> + function.return %self : !struct.type<@B> + } + function.def @constrain(%self : !struct.type<@B>) { + %a = struct.readm %self[@a] : !struct.type<@B>, !struct.type<@A> + function.call @A::@constrain(%a) : (!struct.type<@A>) -> () + function.return + } + } +} +// BASE-LABEL: module attributes {llzk.lang} { +// BASE-NEXT: struct.def @A { +// BASE-NEXT: function.def @compute() -> !struct.type<@A> attributes {function.allow_witness} { +// BASE-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@A> +// BASE-NEXT: function.return %[[VAL_0]] : !struct.type<@A> +// BASE-NEXT: } +// BASE-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@A>) attributes {function.allow_constraint} { +// BASE-NEXT: function.return +// BASE-NEXT: } +// BASE-NEXT: } +// BASE-NEXT: struct.def @B { +// BASE-NEXT: struct.member @a : !struct.type<@A> +// BASE-NEXT: function.def @compute() -> !struct.type<@B> attributes {function.allow_witness} { +// BASE-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@B> +// BASE-NEXT: function.return %[[VAL_2]] : !struct.type<@B> +// BASE-NEXT: } +// BASE-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@B>) attributes {function.allow_constraint} { +// BASE-NEXT: function.return +// BASE-NEXT: } +// BASE-NEXT: } +// BASE-NEXT: } + +// EXTRA-LABEL: module attributes {llzk.lang} { +// EXTRA-NEXT: struct.def @A { +// EXTRA-NEXT: function.def @compute() -> !struct.type<@A> attributes {function.allow_witness} { +// EXTRA-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@A> +// EXTRA-NEXT: function.return %[[VAL_0]] : !struct.type<@A> +// EXTRA-NEXT: } +// EXTRA-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@A>) attributes {function.allow_constraint} { +// EXTRA-NEXT: function.return +// EXTRA-NEXT: } +// EXTRA-NEXT: } +// EXTRA-NEXT: struct.def @B { +// COM: REMOVED: struct.member @a : !struct.type<@A> +// EXTRA-NEXT: function.def @compute() -> !struct.type<@B> attributes {function.allow_witness} { +// EXTRA-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = struct.new : <@B> +// EXTRA-NEXT: function.return %[[VAL_2]] : !struct.type<@B> +// EXTRA-NEXT: } +// EXTRA-NEXT: function.def @constrain(%[[VAL_3:[0-9a-zA-Z_\.]+]]: !struct.type<@B>) attributes {function.allow_constraint} { +// EXTRA-NEXT: function.return +// EXTRA-NEXT: } +// EXTRA-NEXT: } +// EXTRA-NEXT: } diff --git a/test/Transforms/RedundantAndUnusedElim/unused_declaration_pass.llzk b/test/Transforms/RedundantAndUnusedElim/unused_declaration_pass.llzk index f6ec7a4ee7..ed2c67f2c8 100644 --- a/test/Transforms/RedundantAndUnusedElim/unused_declaration_pass.llzk +++ b/test/Transforms/RedundantAndUnusedElim/unused_declaration_pass.llzk @@ -371,19 +371,15 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } -// STRICT-LABEL: module @nested { -// STRICT-NEXT: } - -// STRICT-LABEL: module @user { -// STRICT-NEXT: } - -// STRICT-LABEL: struct.def @Main { -// STRICT-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// STRICT-NEXT: %[[VAL_0:.*]] = struct.new : <@Main> -// STRICT-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> -// STRICT-NEXT: } -// STRICT-NEXT: function.def @constrain(%[[VAL_1:.*]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// STRICT-NEXT: function.return +// STRICT-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// STRICT-NEXT: struct.def @Main { +// STRICT-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// STRICT-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// STRICT-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> +// STRICT-NEXT: } +// STRICT-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// STRICT-NEXT: function.return +// STRICT-NEXT: } // STRICT-NEXT: } // STRICT-NEXT: } @@ -479,19 +475,15 @@ module attributes {llzk.main = !struct.type<@Main>, llzk.lang} { } } -// STRICT-LABEL: module @nested { -// STRICT-NEXT: } - -// STRICT-LABEL: module @user { -// STRICT-NEXT: } - -// STRICT-LABEL: struct.def @Main { -// STRICT-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { -// STRICT-NEXT: %[[VAL_0:.*]] = struct.new : <@Main> -// STRICT-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> -// STRICT-NEXT: } -// STRICT-NEXT: function.def @constrain(%[[VAL_1:.*]]: !struct.type<@Main>) attributes {function.allow_constraint} { -// STRICT-NEXT: function.return +// STRICT-LABEL: module attributes {llzk.lang, llzk.main = !struct.type<@Main>} { +// STRICT-NEXT: struct.def @Main { +// STRICT-NEXT: function.def @compute() -> !struct.type<@Main> attributes {function.allow_witness} { +// STRICT-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Main> +// STRICT-NEXT: function.return %[[VAL_0]] : !struct.type<@Main> +// STRICT-NEXT: } +// STRICT-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Main>) attributes {function.allow_constraint} { +// STRICT-NEXT: function.return +// STRICT-NEXT: } // STRICT-NEXT: } // STRICT-NEXT: } diff --git a/tools/llzk-opt/llzk-opt.cpp b/tools/llzk-opt/llzk-opt.cpp index 50be5ecbc6..6f503b0c70 100644 --- a/tools/llzk-opt/llzk-opt.cpp +++ b/tools/llzk-opt/llzk-opt.cpp @@ -15,6 +15,7 @@ #include "r1cs/Dialect/IR/Dialect.h" #include "r1cs/DialectRegistration.h" +#include "r1cs/Transforms/TransformationPassPipelines.h" #include "r1cs/Transforms/TransformationPasses.h" #include "smt/Conversions/ConversionPasses.h" #include "tools/config.h" @@ -30,6 +31,8 @@ #include "llzk/Dialect/InitDialects.h" #include "llzk/Dialect/POD/Transforms/TransformationPasses.h" #include "llzk/Dialect/Polymorphic/Transforms/TransformationPasses.h" +#include "llzk/Dialect/Struct/Transforms/TransformationPasses.h" +#include "llzk/Transforms/LLZKTransformationPassPipelines.h" #include "llzk/Transforms/LLZKTransformationPasses.h" #include "llzk/Transforms/SpecializedMemoryPasses.h" #include "llzk/Validators/LLZKValidationPasses.h" @@ -82,9 +85,7 @@ inline static void registerTransformsPasses() { mlir::registerMem2Reg(); mlir::registerPrintIRPass(); mlir::registerPrintOpStats(); - mlir::registerPass([]() -> std::unique_ptr { - return llzk::createRemoveDeadValuesWorkaroundPass(); - }); + mlir::registerPass(llzk::createRemoveDeadValuesWorkaroundPass); mlir::registerSCCP(); mlir::registerSROA(); mlir::registerStripDebugInfo(); @@ -122,6 +123,7 @@ int main(int argc, char **argv) { llzk::registerValidationPasses(); llzk::registerAnalysisPasses(); llzk::registerTransformationPasses(); + llzk::component::registerTransformationPasses(); llzk::array::registerTransformationPasses(); llzk::include::registerTransformationPasses(); llzk::polymorphic::registerTransformationPasses(); @@ -133,10 +135,10 @@ int main(int argc, char **argv) { pcl::registerTransformationPasses(); pcl::conversion::registerPCLTransformationPasses(); #endif // LLZK_WITH_PCL + llzk::smt::registerConversionPasses(); llzk::registerTransformationPassPipelines(); r1cs::registerTransformationPassPipelines(); - llzk::smt::registerConversionPasses(); // Register and parse command line options. std::string inputFilename, outputFilename; From 3196c35dc8545ba4545f9d6e83bc288152af003b Mon Sep 17 00:00:00 2001 From: Timothy Hoffman <4001421+tim-hoffman@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:58:46 -0500 Subject: [PATCH 08/12] use `StringAttr` instead of `FlatSymbolRefAttr` in pod access ops (#564) --- .../th__change_record_name_to_str.yaml | 2 + include/llzk/Dialect/POD/IR/OpInterfaces.td | 7 +-- include/llzk/Dialect/POD/IR/Ops.h | 4 +- include/llzk/Dialect/POD/IR/Ops.td | 6 +- lib/Dialect/Function/IR/Ops.cpp | 2 +- lib/Dialect/POD/IR/Ops.cpp | 61 +++++++++++++++++-- .../POD/Transforms/PodToScalarPass.cpp | 48 ++++++--------- lib/Dialect/Struct/IR/Ops.cpp | 2 +- lib/Util/SymbolTableLLZK.cpp | 21 +------ unittests/CAPI/Dialect/POD.cpp | 4 +- 10 files changed, 88 insertions(+), 69 deletions(-) create mode 100644 changelogs/unreleased/th__change_record_name_to_str.yaml diff --git a/changelogs/unreleased/th__change_record_name_to_str.yaml b/changelogs/unreleased/th__change_record_name_to_str.yaml new file mode 100644 index 0000000000..025bf38ac7 --- /dev/null +++ b/changelogs/unreleased/th__change_record_name_to_str.yaml @@ -0,0 +1,2 @@ +changed: + - use `StringAttr` instead of `FlatSymbolRefAttr` for record name in pod read/write ops diff --git a/include/llzk/Dialect/POD/IR/OpInterfaces.td b/include/llzk/Dialect/POD/IR/OpInterfaces.td index 2c5b97d67b..adcccd58bf 100644 --- a/include/llzk/Dialect/POD/IR/OpInterfaces.td +++ b/include/llzk/Dialect/POD/IR/OpInterfaces.td @@ -49,18 +49,13 @@ def PodAccessOpInterface // Requires implementors to have a '$record_name' attribute. InterfaceMethod< [{Gets the record name attribute from the pod access op.}], - "::mlir::FlatSymbolRefAttr", "getRecordNameAttr", (ins)>, + "::mlir::StringAttr", "getRecordNameAttr", (ins)>, InterfaceMethod< [{Return `true` if the op is a read, `false` if it's a write.}], "bool", "isRead", (ins)>, ]; let extraClassDeclaration = [{ - /// Gets the record name as an attribute suitable for destructuring indices. - inline ::mlir::StringAttr getRecordNameAsStringAttr() { - return getRecordNameAttr().getLeafReference(); - } - /// Required by companion interface DestructurableAccessorOpInterface / SROA pass bool canRewire(const ::mlir::DestructurableMemorySlot &slot, ::llvm::SmallPtrSetImpl<::mlir::Attribute> &usedIndices, diff --git a/include/llzk/Dialect/POD/IR/Ops.h b/include/llzk/Dialect/POD/IR/Ops.h index 9997325174..2b890fa007 100644 --- a/include/llzk/Dialect/POD/IR/Ops.h +++ b/include/llzk/Dialect/POD/IR/Ops.h @@ -32,7 +32,7 @@ namespace llzk::pod { mlir::SmallVector getInitializedRecordValues(mlir::ValueRange initialValues, mlir::ArrayAttr initializedRecords); -mlir::ParseResult parseRecordName(mlir::AsmParser &parser, mlir::FlatSymbolRefAttr &name); -void printRecordName(mlir::AsmPrinter &printer, mlir::Operation *, mlir::FlatSymbolRefAttr name); +mlir::ParseResult parseRecordName(mlir::AsmParser &parser, mlir::StringAttr &name); +void printRecordName(mlir::AsmPrinter &printer, mlir::Operation *, mlir::StringAttr name); } // namespace llzk::pod diff --git a/include/llzk/Dialect/POD/IR/Ops.td b/include/llzk/Dialect/POD/IR/Ops.td index f94976084b..0442abd595 100644 --- a/include/llzk/Dialect/POD/IR/Ops.td +++ b/include/llzk/Dialect/POD/IR/Ops.td @@ -239,11 +239,12 @@ def LLZK_ReadPodOp : ScalarPODAccessOp<"read", 1> { }]; let arguments = (ins Arg:$pod_ref, - FlatSymbolRefAttr:$record_name); + StrAttr:$record_name); let results = (outs AnyLLZKType:$result); let assemblyFormat = [{ $pod_ref `[` custom($record_name) `]` `:` type($pod_ref) `,` type($result) attr-dict }]; + let useCustomPropertiesEncoding = 1; let hasVerifier = 1; } @@ -264,11 +265,12 @@ def LLZK_WritePodOp : ScalarPODAccessOp<"write", 0> { }]; let arguments = (ins Arg:$pod_ref, - FlatSymbolRefAttr:$record_name, AnyLLZKType:$value); + StrAttr:$record_name, AnyLLZKType:$value); let assemblyFormat = [{ $pod_ref `[` custom($record_name) `]` `=` $value `:` type($pod_ref) `,` type($value) attr-dict }]; + let useCustomPropertiesEncoding = 1; let hasVerifier = 1; } diff --git a/lib/Dialect/Function/IR/Ops.cpp b/lib/Dialect/Function/IR/Ops.cpp index 4c478c97de..f98fdf5513 100644 --- a/lib/Dialect/Function/IR/Ops.cpp +++ b/lib/Dialect/Function/IR/Ops.cpp @@ -552,7 +552,7 @@ LogicalResult CallOp::readProperties(DialectBytecodeReader &reader, OperationSta return success(); } -// Same as tablegen would generate to serialize version 2 IR. +// Same as tablegen would generate to serialize current version IR. void CallOp::writeProperties(DialectBytecodeWriter &writer) { auto &prop = getProperties(); writer.writeAttribute(prop.callee); diff --git a/lib/Dialect/POD/IR/Ops.cpp b/lib/Dialect/POD/IR/Ops.cpp index 5485c94bdb..e8065a3e2a 100644 --- a/lib/Dialect/POD/IR/Ops.cpp +++ b/lib/Dialect/POD/IR/Ops.cpp @@ -11,6 +11,7 @@ #include "llzk/Dialect/Array/IR/Types.h" #include "llzk/Dialect/LLZK/IR/Ops.h" +#include "llzk/Dialect/LLZK/IR/Versioning.h" #include "llzk/Dialect/POD/IR/Types.h" #include "llzk/Dialect/Struct/IR/Types.h" #include "llzk/Util/TypeHelper.h" @@ -470,7 +471,7 @@ bool PodAccessOpInterface::canRewire( return false; } - StringAttr recordName = getRecordNameAsStringAttr(); + StringAttr recordName = getRecordNameAttr(); if (!slot.subelementTypes.contains(recordName)) { return false; } @@ -487,7 +488,7 @@ DeletionKind PodAccessOpInterface::rewire( assert(slot.ptr == getPodRef()); assert(slot.elemType == getPodRefType()); - StringAttr recordName = getRecordNameAsStringAttr(); + StringAttr recordName = getRecordNameAttr(); const MemorySlot &memorySlot = subslots.at(recordName); getPodRefMutable().set(memorySlot.ptr); @@ -498,6 +499,42 @@ DeletionKind PodAccessOpInterface::rewire( // ReadPodOp //===----------------------------------------------------------------------===// +namespace { + +LogicalResult readRecordNameProperty(DialectBytecodeReader &reader, StringAttr &recordName) { + auto versionOpt = reader.getDialectVersion(); + if (succeeded(versionOpt)) { + const auto &ver = static_cast(**versionOpt); + if (ver.majorVersion < 3) { + // Prior to v3 it was serialized as a `FlatSymbolRefAttr` instead of a `StringAttr`. + FlatSymbolRefAttr attr; + if (failed(reader.readAttribute(attr))) { + return failure(); + } + recordName = attr.getAttr(); + return success(); + } + } + + // Same as tablegen would generate to deserialize current-version IR. + return reader.readAttribute(recordName); +} + +void writeRecordNameProperty(DialectBytecodeWriter &writer, StringAttr recordName) { + writer.writeAttribute(recordName); +} + +} // namespace + +LogicalResult ReadPodOp::readProperties(DialectBytecodeReader &reader, OperationState &state) { + auto &prop = state.getOrAddProperties(); + return readRecordNameProperty(reader, prop.record_name); +} + +void ReadPodOp::writeProperties(DialectBytecodeWriter &writer) { + writeRecordNameProperty(writer, getProperties().record_name); +} + LogicalResult ReadPodOp::verify() { auto podTy = llvm::dyn_cast(getPodRef().getType()); if (!podTy) { @@ -522,6 +559,15 @@ LogicalResult ReadPodOp::verify() { // WritePodOp //===----------------------------------------------------------------------===// +LogicalResult WritePodOp::readProperties(DialectBytecodeReader &reader, OperationState &state) { + auto &prop = state.getOrAddProperties(); + return readRecordNameProperty(reader, prop.record_name); +} + +void WritePodOp::writeProperties(DialectBytecodeWriter &writer) { + writeRecordNameProperty(writer, getProperties().record_name); +} + LogicalResult WritePodOp::verify() { auto podTy = llvm::dyn_cast(getPodRef().getType()); if (!podTy) { @@ -546,11 +592,16 @@ LogicalResult WritePodOp::verify() { // Parsing/Printing helpers //===----------------------------------------------------------------------===// -ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name) { - return parser.parseCustomAttributeWithFallback(name); +ParseResult parseRecordName(AsmParser &parser, StringAttr &name) { + FlatSymbolRefAttr symRef; + auto result = parser.parseCustomAttributeWithFallback(symRef); + if (succeeded(result)) { + name = symRef.getAttr(); + } + return result; } -void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name) { +void printRecordName(AsmPrinter &printer, Operation *, StringAttr name) { printer.printSymbolName(name.getValue()); } diff --git a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp index a44b9af656..4e3778609d 100644 --- a/lib/Dialect/POD/Transforms/PodToScalarPass.cpp +++ b/lib/Dialect/POD/Transforms/PodToScalarPass.cpp @@ -759,24 +759,14 @@ step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementM return applyFullConversion(modOp, target, std::move(patterns)); } -/// Normalize the record name representation used by POD access ops to a plain `StringAttr`. -inline static StringAttr getRecordNameAsStringAttr(ReadPodOp readOp) { - return readOp.getRecordNameAttr().getLeafReference(); -} - -/// Normalize the record name representation used by POD access ops to a plain `StringAttr`. -inline static StringAttr getRecordNameAsStringAttr(WritePodOp writeOp) { - return writeOp.getRecordNameAttr().getLeafReference(); -} - /// Return whether the given read/write access targets the same POD record. inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) { - return readOp.getPodRef() == podRef && getRecordNameAsStringAttr(readOp) == recordName; + return readOp.getPodRef() == podRef && readOp.getRecordNameAttr() == recordName; } /// Return whether the given read/write access targets the same POD record. inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) { - return writeOp.getPodRef() == podRef && getRecordNameAsStringAttr(writeOp) == recordName; + return writeOp.getPodRef() == podRef && writeOp.getRecordNameAttr() == recordName; } /// Return whether `op` contains a nested write to `podRef.recordName`. @@ -803,7 +793,7 @@ static bool hasValueUse(Operation &op, Value value) { /// Return whether the read is preceded by a write to the same pod record within its block. static bool hasEarlierWriteInBlock(ReadPodOp readOp) { Value podRef = readOp.getPodRef(); - StringAttr recordName = getRecordNameAsStringAttr(readOp); + StringAttr recordName = readOp.getRecordNameAttr(); for (Operation &op : *readOp->getBlock()) { if (&op == readOp.getOperation()) { @@ -854,7 +844,7 @@ static WritePodOp findPrecedingWriteForIfRead(ReadPodOp readOp) { } Value podRef = readOp.getPodRef(); - StringAttr recordName = getRecordNameAsStringAttr(readOp); + StringAttr recordName = readOp.getRecordNameAttr(); WritePodOp replacement = nullptr; for (Operation &op : *ifBlock) { if (&op == ifOp.getOperation()) { @@ -897,9 +887,8 @@ class ReplaceIfReadPattern final : public OpRewritePattern { rewriter.setInsertionPoint(ifOp); rewriter.replaceOp( - readOp, - genRead(readOp.getLoc(), readOp.getPodRef(), getRecordNameAsStringAttr(readOp), rewriter) - .getResult() + readOp, genRead(readOp.getLoc(), readOp.getPodRef(), readOp.getRecordNameAttr(), rewriter) + .getResult() ); return success(); } @@ -934,7 +923,7 @@ class FoldIfCarriedPodReadAfterWritePattern final : public OpRewritePattern(readOp->getPrevNode()); - if (!writeOp || getRecordNameAsStringAttr(writeOp) != getRecordNameAsStringAttr(readOp)) { + if (!writeOp || writeOp.getRecordNameAttr() != readOp.getRecordNameAttr()) { return failure(); } @@ -1022,7 +1011,7 @@ collectDirectWrites(Block *block, bool isThenBlock, SmallVectorImpl } IfWriteSlot &slot = getOrCreateSlot( - slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp), writeOp.getValue().getType() + slots, writeOp.getPodRef(), writeOp.getRecordNameAttr(), writeOp.getValue().getType() ); if (isThenBlock) { slot.thenWrite = writeOp; @@ -1179,7 +1168,7 @@ collectDirectLoopPodSlots(Block &block, Operation *ancestor, SmallVectorImpl(&op)) { if (!isValueDefinedInside(ancestor, readOp.getPodRef())) { getOrCreateLoopSlot( - slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp), readOp.getType() + slots, readOp.getPodRef(), readOp.getRecordNameAttr(), readOp.getType() ); } continue; @@ -1188,8 +1177,7 @@ collectDirectLoopPodSlots(Block &block, Operation *ancestor, SmallVectorImpl(&op)) { if (!isValueDefinedInside(ancestor, writeOp.getPodRef())) { getOrCreateLoopSlot( - slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp), - writeOp.getValue().getType() + slots, writeOp.getPodRef(), writeOp.getRecordNameAttr(), writeOp.getValue().getType() ); } } @@ -1214,14 +1202,14 @@ static bool hasNestedTrackedPodAccess(Operation &op, ArrayRef slots } if (auto readOp = dyn_cast(nestedOp)) { - if (hasLoopSlot(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) { + if (hasLoopSlot(slots, readOp.getPodRef(), readOp.getRecordNameAttr())) { return WalkResult::interrupt(); } return WalkResult::advance(); } if (auto writeOp = dyn_cast(nestedOp)) { - if (hasLoopSlot(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) { + if (hasLoopSlot(slots, writeOp.getPodRef(), writeOp.getRecordNameAttr())) { return WalkResult::interrupt(); } } @@ -1375,7 +1363,7 @@ class LiftPodAccessesFromForLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) { + findLoopSlotIndex(slots, readOp.getPodRef(), readOp.getRecordNameAttr())) { mapping.map(readOp.getResult(), slotValues[*slotIdx]); continue; } @@ -1383,7 +1371,7 @@ class LiftPodAccessesFromForLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) { + findLoopSlotIndex(slots, writeOp.getPodRef(), writeOp.getRecordNameAttr())) { slotValues[*slotIdx] = mapping.lookupOrDefault(writeOp.getValue()); continue; } @@ -1472,7 +1460,7 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) { + findLoopSlotIndex(slots, readOp.getPodRef(), readOp.getRecordNameAttr())) { beforeMapping.map(readOp.getResult(), beforeSlotValues[*slotIdx]); continue; } @@ -1480,7 +1468,7 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) { + findLoopSlotIndex(slots, writeOp.getPodRef(), writeOp.getRecordNameAttr())) { beforeSlotValues[*slotIdx] = beforeMapping.lookupOrDefault(writeOp.getValue()); continue; } @@ -1518,7 +1506,7 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) { + findLoopSlotIndex(slots, readOp.getPodRef(), readOp.getRecordNameAttr())) { afterMapping.map(readOp.getResult(), afterSlotValues[*slotIdx]); continue; } @@ -1526,7 +1514,7 @@ class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern(&op)) { if (std::optional slotIdx = - findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) { + findLoopSlotIndex(slots, writeOp.getPodRef(), writeOp.getRecordNameAttr())) { afterSlotValues[*slotIdx] = afterMapping.lookupOrDefault(writeOp.getValue()); continue; } diff --git a/lib/Dialect/Struct/IR/Ops.cpp b/lib/Dialect/Struct/IR/Ops.cpp index 5e30b312e5..12a7e0bc5b 100644 --- a/lib/Dialect/Struct/IR/Ops.cpp +++ b/lib/Dialect/Struct/IR/Ops.cpp @@ -513,7 +513,7 @@ LogicalResult StructDefOp::readProperties(DialectBytecodeReader &reader, Operati return reader.readAttribute(prop.sym_name); } -// Same as tablegen would generate to serialize version 2 IR. +// Same as tablegen would generate to serialize current version IR. void StructDefOp::writeProperties(DialectBytecodeWriter &writer) { auto &prop = getProperties(); writer.writeAttribute(prop.sym_name); diff --git a/lib/Util/SymbolTableLLZK.cpp b/lib/Util/SymbolTableLLZK.cpp index 139237ed70..c87a724cad 100644 --- a/lib/Util/SymbolTableLLZK.cpp +++ b/lib/Util/SymbolTableLLZK.cpp @@ -35,8 +35,6 @@ #include "llzk/Util/SymbolTableLLZK.h" -#include "llzk/Dialect/POD/IR/Ops.h" - #include using namespace mlir; @@ -149,24 +147,7 @@ walkSymbolRefs(Operation *op, function_ref c return WalkResult::interrupt(); } } - - // TODO: Remove this when POD types are updated to use StringAttr. - // POD record names are encoded as FlatSymbolRefAttr for parsing/printing - // convenience, but they are not real symbol references and must not be - // surfaced as symbol uses. - auto shouldSkipAttr = [op](NamedAttribute attr) { - return attr.getName() == "record_name" && - (isa(op) || isa(op)); - }; - for (NamedAttribute attr : op->getAttrs()) { - if (shouldSkipAttr(attr)) { - continue; - } - if (attr.getValue().walk(walkFn).wasInterrupted()) { - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); + return op->getAttrDictionary().walk(walkFn); } /// Walk all of the uses, for any symbol, that are nested within the given diff --git a/unittests/CAPI/Dialect/POD.cpp b/unittests/CAPI/Dialect/POD.cpp index fc3f482edb..231451a67b 100644 --- a/unittests/CAPI/Dialect/POD.cpp +++ b/unittests/CAPI/Dialect/POD.cpp @@ -198,7 +198,7 @@ std::unique_ptr ReadPodOpBuildFuncHelper::get() { unwrap(testClass.context), {createRecordAttrCpp(name, unwrap(indexTy))} ); auto newPodOp = unwrap(builder)->create(unwrap(location), podTy); - auto recordName = mlir::FlatSymbolRefAttr::get(unwrap(testClass.context), name); + auto recordName = mlir::StringAttr::get(unwrap(testClass.context), name); return llzkPod_ReadPodOpBuild( builder, location, indexTy, wrap(newPodOp.getResult()), wrap(recordName) ); @@ -222,7 +222,7 @@ std::unique_ptr WritePodOpBuildFuncHelper::get() { unwrap(testClass.context), {createRecordAttrCpp(name, unwrap(indexTy))} ); auto newPodOp = unwrap(builder)->create(unwrap(location), podTy); - auto recordName = mlir::FlatSymbolRefAttr::get(unwrap(testClass.context), name); + auto recordName = mlir::StringAttr::get(unwrap(testClass.context), name); return llzkPod_WritePodOpBuild( builder, location, wrap(newPodOp.getResult()), mlirOperationGetResult(testClass.createIndexOperation(), 0), wrap(recordName) From 88333ce2cf125fd0eb3932deada99dfcb96702dd Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" Date: Fri, 19 Jun 2026 11:49:50 +0800 Subject: [PATCH 09/12] Clarify aux write ordering comments --- lib/Transforms/LLZKPolyLoweringPass.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/Transforms/LLZKPolyLoweringPass.cpp b/lib/Transforms/LLZKPolyLoweringPass.cpp index 35bb9fc175..3ee780e378 100644 --- a/lib/Transforms/LLZKPolyLoweringPass.cpp +++ b/lib/Transforms/LLZKPolyLoweringPass.cpp @@ -86,6 +86,8 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { const llvm::StringMap &auxNameToIndex, DenseSet &visitedValues, DenseSet &seenDeps, SmallVectorImpl &deps ) const { + // Aux dependencies can appear as generated aux SSA values or reads of generated + // aux members, so track both forms before ordering writes. if (!val || !visitedValues.insert(val).second) { return; } @@ -127,6 +129,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { } visitState[idx] = 1; + // Emit prerequisite aux writes before the aux writes that read them. for (unsigned dep : deps[idx]) { if (failed(visitAuxAssignment(dep, deps, visitState, ordered, auxAssignments))) { return failure(); @@ -579,6 +582,8 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { rebuiltExpr ); if (assign.auxValue) { + // Reuse the expression just written so later aux producers do not need an + // immediate read from the generated aux member. rebuildMemo[assign.auxValue] = rebuiltExpr; } } From 5d87a3ea0a3a0556887ad0e6b0d06188cd80a9cd Mon Sep 17 00:00:00 2001 From: foldr Date: Fri, 19 Jun 2026 11:08:11 +0200 Subject: [PATCH 10/12] Loop invariant operations (#540) * Invariant op completed * Inner invariant operations * Changelog * Move new tests to a separate file * Remove trailing newline Co-authored-by: Ian Neal * Add remaining tests * Fix formatting * PR feedback * PR feedback * Fix formatting --------- Co-authored-by: Ian Neal --- .../unreleased/dani__verif-invariant-ops.yaml | 2 + include/llzk-c/Dialect/Verif.h | 16 ++ include/llzk/Dialect/Function/IR/Ops.h | 1 + include/llzk/Dialect/Function/IR/Ops.td | 12 +- include/llzk/Dialect/Shared/Builders.h | 12 +- include/llzk/Dialect/Struct/IR/Ops.h | 1 + include/llzk/Dialect/Struct/IR/Ops.td | 11 +- include/llzk/Dialect/Verif/IR/Dialect.h | 11 + include/llzk/Dialect/Verif/IR/OpInterfaces.h | 24 ++ include/llzk/Dialect/Verif/IR/OpInterfaces.td | 47 ++++ include/llzk/Dialect/Verif/IR/Ops.h | 10 +- include/llzk/Dialect/Verif/IR/Ops.td | 174 ++++++++++++ lib/CAPI/Dialect/Verif.cpp | 27 ++ lib/Dialect/InitDialects.cpp | 1 + lib/Dialect/Shared/Builders.cpp | 13 +- lib/Dialect/Verif/IR/Dialect.cpp | 89 +++++++ lib/Dialect/Verif/IR/Ops.cpp | 169 ++++++++++++ test/Dialect/Verif/invariant_inners_fail.llzk | 101 +++++++ test/Dialect/Verif/invariant_inners_pass.llzk | 174 ++++++++++++ test/Dialect/Verif/invariants_fail.llzk | 105 ++++++++ test/Dialect/Verif/invariants_pass.llzk | 248 ++++++++++++++++++ unittests/CAPI/Dialect/Verif.cpp | 203 +++++++++++++- 22 files changed, 1426 insertions(+), 25 deletions(-) create mode 100644 changelogs/unreleased/dani__verif-invariant-ops.yaml create mode 100644 include/llzk/Dialect/Verif/IR/OpInterfaces.h create mode 100644 test/Dialect/Verif/invariant_inners_fail.llzk create mode 100644 test/Dialect/Verif/invariant_inners_pass.llzk create mode 100644 test/Dialect/Verif/invariants_fail.llzk create mode 100644 test/Dialect/Verif/invariants_pass.llzk diff --git a/changelogs/unreleased/dani__verif-invariant-ops.yaml b/changelogs/unreleased/dani__verif-invariant-ops.yaml new file mode 100644 index 0000000000..2c29bdb6de --- /dev/null +++ b/changelogs/unreleased/dani__verif-invariant-ops.yaml @@ -0,0 +1,2 @@ +added: + - Operations in the `verif` dialect for expression loop invariants diff --git a/include/llzk-c/Dialect/Verif.h b/include/llzk-c/Dialect/Verif.h index 7314addb45..53b01e8730 100644 --- a/include/llzk-c/Dialect/Verif.h +++ b/include/llzk-c/Dialect/Verif.h @@ -33,6 +33,11 @@ extern "C" { /// Get reference to the LLZK `verif` dialect. MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Verif, llzk__verif); +/// Attaches the interfaces defined by the `verif` dialect to upstream IR elements +/// +/// Attempting to use those interfaces without calling this function first will result in an error. +MLIR_CAPI_EXPORTED void llzkVerif_attachInterfaces(MlirContext context); + /// Build a `verif.contract` from explicit attributes and signature metadata. LLZK_DECLARE_OP_BUILD_METHOD( Verif, ContractOp, MlirIdentifier sym_name, MlirAttribute target, MlirAttribute function_type, @@ -55,6 +60,17 @@ LLZK_DECLARE_OP_BUILD_METHOD( Verif, IncludeOp, MlirAttribute callee, MlirValueRange argOperands, MlirAttribute templateParams ); +/// Build a `verif.invariant` with a list of argument types and locations. +/// +/// The pointers to the types and locations must point to the same amount of elements. +LLZK_DECLARE_OP_BUILD_METHOD( + Verif, InvariantOp, MlirStringRef loopName, intptr_t numArgs, MlirType const *types, + MlirLocation const *locs +); + +/// Returns the body of the invariant operation. +MLIR_CAPI_EXPORTED MlirBlock llzkVerif_InvariantOpGetBody(MlirOperation op); + #ifdef __cplusplus } #endif diff --git a/include/llzk/Dialect/Function/IR/Ops.h b/include/llzk/Dialect/Function/IR/Ops.h index ae30f6d453..c221cf0ed6 100644 --- a/include/llzk/Dialect/Function/IR/Ops.h +++ b/include/llzk/Dialect/Function/IR/Ops.h @@ -13,6 +13,7 @@ #include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/Shared/OpHelpers.h" #include "llzk/Dialect/Struct/IR/Ops.h" +#include "llzk/Dialect/Verif/IR/OpInterfaces.h" #include "llzk/Util/Constants.h" #include "llzk/Util/SymbolHelper.h" diff --git a/include/llzk/Dialect/Function/IR/Ops.td b/include/llzk/Dialect/Function/IR/Ops.td index edc1794962..80b1b4cd5d 100644 --- a/include/llzk/Dialect/Function/IR/Ops.td +++ b/include/llzk/Dialect/Function/IR/Ops.td @@ -16,6 +16,7 @@ #define LLZK_FUNC_OPS include "llzk/Dialect/Function/IR/Dialect.td" +include "llzk/Dialect/Verif/IR/OpInterfaces.td" include "llzk/Dialect/Shared/OpTraits.td" include "llzk/Dialect/Shared/Types.td" @@ -36,11 +37,12 @@ class FunctionDialectOp traits = []> def FuncDefOp : FunctionDialectOp< - "def", - [ParentOneOf<["::mlir::ModuleOp", "::llzk::component::StructDefOp", - "::llzk::polymorphic::TemplateOp"]>, - DeclareOpInterfaceMethods, AffineScope, - AutomaticAllocationScope, FunctionOpInterface, IsolatedFromAbove]> { + "def", [ParentOneOf<["::mlir::ModuleOp", + "::llzk::component::StructDefOp", + "::llzk::polymorphic::TemplateOp"]>, + DeclareOpInterfaceMethods, AffineScope, + AutomaticAllocationScope, FunctionOpInterface, + IsolatedFromAbove, ContractTarget]> { // NOTE: Cannot have SymbolTable trait because that would cause global // functions without a body to produce "Operations with a 'SymbolTable' must // have exactly one block" diff --git a/include/llzk/Dialect/Shared/Builders.h b/include/llzk/Dialect/Shared/Builders.h index 7deec30135..b15d32202a 100644 --- a/include/llzk/Dialect/Shared/Builders.h +++ b/include/llzk/Dialect/Shared/Builders.h @@ -285,9 +285,15 @@ template class ModuleLikeBuilder : public BaseBuilder { return insertConstrainCall(caller, callee, getUnknownLoc(), getUnknownLoc()); } - Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc); - inline Derived &insertFreeFunc(std::string_view funcName, ::mlir::FunctionType type) { - return insertFreeFunc(funcName, type, getUnknownLoc()); + Derived &insertFreeFunc( + std::string_view funcName, ::mlir::FunctionType type, mlir::Location loc, + llvm::function_ref fnBody = nullptr + ); + inline Derived &insertFreeFunc( + std::string_view funcName, ::mlir::FunctionType type, + llvm::function_ref fnBody = nullptr + ) { + return insertFreeFunc(funcName, type, getUnknownLoc(), fnBody); } Derived & diff --git a/include/llzk/Dialect/Struct/IR/Ops.h b/include/llzk/Dialect/Struct/IR/Ops.h index af82a77168..f73115b299 100644 --- a/include/llzk/Dialect/Struct/IR/Ops.h +++ b/include/llzk/Dialect/Struct/IR/Ops.h @@ -14,6 +14,7 @@ #include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/Shared/OpHelpers.h" #include "llzk/Dialect/Struct/IR/Types.h" +#include "llzk/Dialect/Verif/IR/OpInterfaces.h" namespace llzk { diff --git a/include/llzk/Dialect/Struct/IR/Ops.td b/include/llzk/Dialect/Struct/IR/Ops.td index 55d57bc394..650a0ea3d9 100644 --- a/include/llzk/Dialect/Struct/IR/Ops.td +++ b/include/llzk/Dialect/Struct/IR/Ops.td @@ -21,6 +21,7 @@ include "llzk/Dialect/Struct/IR/Dialect.td" include "llzk/Dialect/Struct/IR/OpInterfaces.td" include "llzk/Dialect/Struct/IR/Types.td" include "llzk/Dialect/Shared/OpTraits.td" +include "llzk/Dialect/Verif/IR/OpInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" @@ -40,11 +41,11 @@ def SetFuncAllowAttrs : NativeOpTrait<"SetFuncAllowAttrs">, StructuralOpTrait { //===------------------------------------------------------------------===// def LLZK_StructDefOp - : StructDialectOp< - "def", [ParentOneOf<["::mlir::ModuleOp", - "::llzk::polymorphic::TemplateOp"]>, - Symbol, LLZKSymbolTable, IsolatedFromAbove, SetFuncAllowAttrs, - NoRegionArguments]#GraphRegionNoTerminator.traits> { + : StructDialectOp<"def", [ParentOneOf<["::mlir::ModuleOp", + "::llzk::polymorphic::TemplateOp"]>, + Symbol, LLZKSymbolTable, IsolatedFromAbove, + SetFuncAllowAttrs, NoRegionArguments, + ContractTarget]#GraphRegionNoTerminator.traits> { let summary = "circuit component definition"; let description = [{ This operation describes a component in a circuit. It can contain any number diff --git a/include/llzk/Dialect/Verif/IR/Dialect.h b/include/llzk/Dialect/Verif/IR/Dialect.h index 77747d350a..4922b4556b 100644 --- a/include/llzk/Dialect/Verif/IR/Dialect.h +++ b/include/llzk/Dialect/Verif/IR/Dialect.h @@ -10,6 +10,17 @@ #pragma once #include +#include // Include TableGen'd declarations #include "llzk/Dialect/Verif/IR/Dialect.h.inc" + +namespace llzk::verif { +/// Attaches the interfaces defined by the `verif` dialect to upstream IR elements. +/// +/// Attempting to use those interfaces without calling this function first will result in an error. +void attachInterfaces(mlir::MLIRContext &context); + +/// Registers dialect extensions for the verif dialect. +void registerExtensions(mlir::DialectRegistry ®istry); +} // namespace llzk::verif diff --git a/include/llzk/Dialect/Verif/IR/OpInterfaces.h b/include/llzk/Dialect/Verif/IR/OpInterfaces.h new file mode 100644 index 0000000000..5539762a3b --- /dev/null +++ b/include/llzk/Dialect/Verif/IR/OpInterfaces.h @@ -0,0 +1,24 @@ +//===-- OpInterfaces.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLZK Project, under the Apache License v2.0. +// See LICENSE.txt for license information. +// Copyright 2026 Project LLZK. +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "llzk/Dialect/Polymorphic/IR/Ops.h" +#include "llzk/Dialect/Shared/OpHelpers.h" +#include "llzk/Dialect/Verif/IR/Dialect.h" +#include "llzk/Util/TypeHelper.h" +#include "llzk/Util/Walk.h" + +#include +#include +#include +#include + +// Include TableGen'd declarations +#include "llzk/Dialect/Verif/IR/OpInterfaces.h.inc" diff --git a/include/llzk/Dialect/Verif/IR/OpInterfaces.td b/include/llzk/Dialect/Verif/IR/OpInterfaces.td index 26ee48abc9..98074ba7f3 100644 --- a/include/llzk/Dialect/Verif/IR/OpInterfaces.td +++ b/include/llzk/Dialect/Verif/IR/OpInterfaces.td @@ -13,6 +13,7 @@ include "mlir/IR/Interfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" def ConditionOpInterface : OpInterface<"ConditionOpInterface", [DeclareOpInterfaceMethods< @@ -48,4 +49,50 @@ def PostconditionOpInterface let cppNamespace = "::llzk::verif"; } +def InvariantTarget : OpInterface<"InvariantTargetOpInterface"> { + let description = [{ + Interface implemented by operations that can be targeted by a loop invariant. + + Targets are referenced by a label that implementations of this interface must provide. + }]; + let cppNamespace = "::llzk::verif"; + + let methods = + [InterfaceMethod<[{Gets the label of the loop, if available.}], + "::mlir::FailureOr<::mlir::StringRef>", "getLabel", + (ins), [{ + auto attr = ::mlir::dyn_cast_or_null<::mlir::StringAttr>($_op->getDiscardableAttr("loop_label")); + if (!attr) return ::mlir::failure(); + return attr.getValue(); + }]>, + InterfaceMethod< + [{Gets the types of the values that the invariant binds inside its body.}], + "::mlir::SmallVector<::mlir::Type>", "getArgumentTypes", (ins)>, + ]; +} + +def ContractTarget : OpInterface<"ContractTargetOpInterface", [Symbol]> { + let description = [{ + Interface implemented by operations that can be targeted by a contract. + + Targets are referenced by name so implementations of this interface must also implement the `Symbol` interface. + }]; + let cppNamespace = "::llzk::verif"; + + let methods = [InterfaceMethod< + [{Gets the ops that can be targeted by invariant ops.}], + "::llvm::SmallVector<::llzk::verif::InvariantTargetOpInterface>", + "getLoops", (ins), [{ + auto op = $_op; // walkCollect expects an l-value but the expansion of $_op is not. + return walkCollect<::llzk::verif::InvariantTargetOpInterface>(op); + }]>]; + + let extraClassDeclaration = [{ + /// Emulates the static method with the same name found in operations. + static ::mlir::StringLiteral getOperationName() { + return ::mlir::StringLiteral("contract target interface implementation"); + } + }]; +} + #endif // LLZK_VERIF_OP_INTERFACES diff --git a/include/llzk/Dialect/Verif/IR/Ops.h b/include/llzk/Dialect/Verif/IR/Ops.h index eed786d9d1..ada75bb6cd 100644 --- a/include/llzk/Dialect/Verif/IR/Ops.h +++ b/include/llzk/Dialect/Verif/IR/Ops.h @@ -9,20 +9,14 @@ #pragma once +#include "llzk/Dialect/Felt/IR/Types.h" #include "llzk/Dialect/Function/IR/Ops.h" #include "llzk/Dialect/Polymorphic/IR/Ops.h" #include "llzk/Dialect/Shared/OpHelpers.h" #include "llzk/Dialect/Verif/IR/Dialect.h" +#include "llzk/Dialect/Verif/IR/OpInterfaces.h" #include "llzk/Util/TypeHelper.h" -#include -#include -#include -#include - -// Include TableGen'd declarations -#include "llzk/Dialect/Verif/IR/OpInterfaces.h.inc" - // Include TableGen'd declarations #define GET_OP_CLASSES #include "llzk/Dialect/Verif/IR/Ops.h.inc" diff --git a/include/llzk/Dialect/Verif/IR/Ops.td b/include/llzk/Dialect/Verif/IR/Ops.td index a633dc3416..ba2b995872 100644 --- a/include/llzk/Dialect/Verif/IR/Ops.td +++ b/include/llzk/Dialect/Verif/IR/Ops.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "llzk/Dialect/Shared/OpTraits.td" include "llzk/Dialect/Shared/Types.td" +include "llzk/Dialect/Felt/IR/Types.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpBase.td" @@ -287,6 +288,15 @@ def ContractOp return getFuncTarget(tables); } + /// Return the operation that this contract targets, or failure if it does not + /// target an operation that implements the `ContractTarget` op interface or is not found. + ::mlir::FailureOr> getTargetOp(::mlir::SymbolTableCollection &tables); + + ::mlir::FailureOr> getTargetOp() { + ::mlir::SymbolTableCollection tables; + return getTargetOp(tables); + } + private: /// Populate a builder-created contract body with one entry block matching /// the function signature and insert the implicit `verif.contract_end`. @@ -481,4 +491,168 @@ def IncludeOp }]; } +//===------------------------------------------------------------------===// +// InvariantOp +//===------------------------------------------------------------------===// + +def InvariantOp + : VerifDialectOp<"invariant", [HasAncestor<"::llzk::verif::ContractOp">, + NoTerminator, SingleBlock]> { + + let summary = "loop invariant definition operation"; + let description = [{ + Defines an invariant for a loop inside the target of a contract. + + The targeted loop op must implement the `InvariantTargetOpInterface`. + This interface is already implemented by the `scf.while` and `scf.for` + operations. The loop name is defined on either of those ops with a + string attribute named `loop_label`. + + The arguments of the body must match those declared by the target op. + In the case of `scf.for` the control values must also be block arguments + in the following order: lower bound, induction variable, upper bound, and stride. + + Example: + ``` + module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %0 = felt.const 5 + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%2 = %0) -> !felt.type { + scf.yield %2 : !felt.type + } {loop_label = "loopA"} + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb: index, %iv: index, %ub: index, %step: index, %extra: !felt.type) { + %true = arith.constant true + verif.require_compute %true + } + } + } + ``` + }]; + + let arguments = (ins StrAttr:$loop_name, TypeArrayAttr:$loop_arg_types); + + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder<(ins "::mlir::StringRef":$loop_name, + CArg<"::llvm::ArrayRef<::mlir::Type>", "{}">:$loop_arg_types, + CArg<"::llvm::ArrayRef<::mlir::Location>", "{}">:$loop_arg_locs)>]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Returns the contract operation that contains this invariant. + ::llzk::verif::ContractOp getParentContract(); + /// Returns the loop target. + ::mlir::FailureOr<::llzk::verif::InvariantTargetOpInterface> getTarget(); + }]; +} + +//===------------------------------------------------------------------===// +// Invariant inner ops +//===------------------------------------------------------------------===// + +class InvariantInnerOp traits = []> + : VerifDialectOp< + mnemonic, + traits#[HasAncestor<"::llzk::verif::InvariantOp">, + DeclareOpInterfaceMethods]> { + + let summary = mnemonic#" value operation"; + let description = [{ + Indicates that a value }]#mnemonic#[{ on each iteration of the loop. + + This declares the `MemoryEffectsOpInterface`, which, like the `cf.assert` (MLIR `cf` dialect) + and `bool.assert` (LLZK `bool` dialect) ops, adds a MemWrite affect to model program termination. + }]; + + let arguments = (ins LLZK_FeltType:$value); + + let extraClassDefinition = [{ + // This side effect models "program termination". Based on + // https://github.com/llvm/llvm-project/blob/f325e4b2d836d6e65a4d0cf3efc6b0996ccf3765/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp#L92-L97 + void $cppClass::getEffects( + ::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects + ) { + effects.emplace_back(::mlir::MemoryEffects::Write::get()); + } + }]; + + let assemblyFormat = "$value attr-dict"; +} + +def IncreasesOp : InvariantInnerOp<"increases"> {} + +def DecreasesOp : InvariantInnerOp<"decreases"> {} + +def StepOp : VerifDialectOp< + "step", [HasAncestor<"::llzk::verif::InvariantOp">, + DeclareOpInterfaceMethods, + NoRegionArguments, SingleBlock]> { + let summary = "step predicate operation"; + let description = [{ + Defines a predicate that must be satisfied between iterations of the loop. + + The predicate can access the value in the previous iteration by using the `verif.old` operation. + + Declares the `MemoryEffectsOpInterface`, which, like the `cf.assert` (MLIR `cf` dialect) + and `bool.assert` (LLZK `bool` dialect) ops, adds a MemWrite affect to model program termination. + }]; + + let regions = (region SizedRegion<1>:$region); + + let extraClassDefinition = [{ + // This side effect models "program termination". Based on + // https://github.com/llvm/llvm-project/blob/f325e4b2d836d6e65a4d0cf3efc6b0996ccf3765/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp#L92-L97 + void $cppClass::getEffects( + ::mlir::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects + ) { + effects.emplace_back(::mlir::MemoryEffects::Write::get()); + } + }]; + + let assemblyFormat = "$region attr-dict"; +} + +def StepYieldOp + : VerifDialectOp<"step.yield", [HasParent<"StepOp">, Terminator]> { + let summary = "step yield operation"; + let description = [{ + Terminator operation for `verif.step` blocks. Takes the boolean value carrying the predicate's result. + }]; + + let arguments = (ins I1:$value); + + let assemblyFormat = "$value attr-dict"; +} + +def OldOp : VerifDialectOp<"old", [Pure, HasAncestor<"::llzk::verif::StepOp">, + AllTypesMatch<["value", "result"]>]> { + let summary = "old operation"; + let description = [{ + This operation allows accessing the value of an expression in the previous iteration of the loop. + + In can only be used within the body of a `verif.step` operation. + }]; + + let arguments = (ins AnyLLZKType:$value); + let results = (outs AnyLLZKType:$result); + + let assemblyFormat = "$value `:` type($result) attr-dict"; +} + #endif // LLZK_VERIF_OPS diff --git a/lib/CAPI/Dialect/Verif.cpp b/lib/CAPI/Dialect/Verif.cpp index d5e4160fe8..eb44bf3e97 100644 --- a/lib/CAPI/Dialect/Verif.cpp +++ b/lib/CAPI/Dialect/Verif.cpp @@ -28,6 +28,10 @@ using namespace llzk::verif; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Verif, llzk__verif, VerifDialect) +void llzkVerif_attachInterfaces(MlirContext context) { + ::llzk::verif::attachInterfaces(*unwrap(context)); +} + //===----------------------------------------------------------------------===// // ContractOp //===----------------------------------------------------------------------===// @@ -95,3 +99,26 @@ LLZK_DEFINE_OP_BUILD_METHOD( ) ); } + +//===----------------------------------------------------------------------===// +// InvariantOp +//===----------------------------------------------------------------------===// + +LLZK_DEFINE_OP_BUILD_METHOD( + Verif, InvariantOp, MlirStringRef loopName, intptr_t numArgs, MlirType const *types, + MlirLocation const *locs +) { + SmallVector typesSto; + SmallVector locsSto; + auto typesRef = unwrapList(numArgs, types, typesSto); + auto locsRef = unwrapList(numArgs, locs, locsSto); + + return mlirOpBuilderInsert( + builder, + wrap(llzk::create(builder, location, unwrap(loopName), typesRef, locsRef)) + ); +} + +MlirBlock llzkVerif_InvariantOpGetBody(MlirOperation op) { + return wrap(unwrap_cast(op).getBody()); +} diff --git a/lib/Dialect/InitDialects.cpp b/lib/Dialect/InitDialects.cpp index c805397e2b..9de6795db4 100644 --- a/lib/Dialect/InitDialects.cpp +++ b/lib/Dialect/InitDialects.cpp @@ -57,5 +57,6 @@ void registerAllDialects(mlir::DialectRegistry ®istry) { >(); registerInliningExtensions(registry); + verif::registerExtensions(registry); } } // namespace llzk diff --git a/lib/Dialect/Shared/Builders.cpp b/lib/Dialect/Shared/Builders.cpp index 30f3efb18a..943bd34d18 100644 --- a/lib/Dialect/Shared/Builders.cpp +++ b/lib/Dialect/Shared/Builders.cpp @@ -11,6 +11,9 @@ #include "llzk/Dialect/LLZK/IR/Dialect.h" +#include +#include + #include namespace llzk { @@ -268,13 +271,19 @@ Derived &ModuleLikeBuilder::insertConstrainCall( template Derived &ModuleLikeBuilder::insertFreeFunc( - std::string_view funcName, FunctionType type, Location loc + std::string_view funcName, FunctionType type, Location loc, + llvm::function_ref fnBody ) { ensureNoSuchFreeFunc(funcName); OpBuilder opBuilder(this->getBodyRegion()); auto funcDef = opBuilder.create(loc, funcName, type); - (void)funcDef.addEntryBlock(); + auto *block = funcDef.addEntryBlock(); + if (fnBody) { + OpBuilder::InsertionGuard guard(opBuilder); + opBuilder.setInsertionPointToEnd(block); + fnBody(opBuilder); + } freeFuncMap[funcName] = funcDef; return static_cast(*this); diff --git a/lib/Dialect/Verif/IR/Dialect.cpp b/lib/Dialect/Verif/IR/Dialect.cpp index 95adb5af95..f3d96a5ea5 100644 --- a/lib/Dialect/Verif/IR/Dialect.cpp +++ b/lib/Dialect/Verif/IR/Dialect.cpp @@ -12,17 +12,104 @@ #include "llzk/Dialect/LLZK/IR/Versioning.h" #include "llzk/Dialect/Verif/IR/Ops.h" +#include #include +#include +#include +#include +#include +#include +#include #include // TableGen'd implementation files #include "llzk/Dialect/Verif/IR/Dialect.cpp.inc" +using namespace mlir; + +//===------------------------------------------------------------------===// +// InvariantTarget implementations for upstream dialects +//===------------------------------------------------------------------===// + +namespace { + +/// Shared implementation of the `getLabel` method. +static FailureOr getLabelImpl(Operation *op) { + auto attr = dyn_cast_or_null(op->getDiscardableAttr("loop_label")); + if (!attr) { + return failure(); + } + return attr.getValue(); +} + +struct ScfWhileExternalModel : public llzk::verif::InvariantTargetOpInterface::ExternalModel< + ScfWhileExternalModel, scf::WhileOp> { + FailureOr getLabel(Operation *op) const { return getLabelImpl(op); } + + SmallVector getArgumentTypes(Operation *op) const { + auto whileOp = cast(op); + // In the case of `scf.while` we return the 'before' arguments. + // Depending on how the loop is constructed it may not be the most ergonomic + // when it comes to binding the loop arguments in an invariant. The limitation of using these + // arguments is that invariants will have trouble expressing properties of a loop that rely + // on intermediate values passed via the 'after' arguments. The downside of using both + // 'before' and 'after' arguments is that any 'before' argument that is passed to the + // 'after' arguments will require a duplicate binding in the invariant, which is probably + // not very user-friendly and may lead to confusion. + return llvm::map_to_vector(whileOp.getBeforeArguments(), [](auto arg) { + return arg.getType(); + }); + } +}; + +struct ScfForExternalModel : public llzk::verif::InvariantTargetOpInterface::ExternalModel< + ScfForExternalModel, scf::ForOp> { + FailureOr getLabel(Operation *op) const { return getLabelImpl(op); } + + SmallVector getArgumentTypes(Operation *op) const { + auto forOp = cast(op); + // In the case of `scf.for` we return the control values in a fixed order defined by the + // spec language semantics followed by any loop carried values. The order for the control values + // is: lower bound, induction variable, upper bound, step. + SmallVector types; + types.reserve(forOp.getNumRegionIterArgs() + 4); + types.append( + {forOp.getLowerBound().getType(), forOp.getInductionVar().getType(), + forOp.getUpperBound().getType(), forOp.getStep().getType()} + ); + types.append(llvm::map_to_vector(forOp.getInitArgs(), [](auto arg) { return arg.getType(); })); + return types; + } +}; + +/// Dialect extension that attaches the interfaces to upstream ops that promised them. +struct InterfacesExtension + : public DialectExtension { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InterfacesExtension) + + void apply(MLIRContext *context, llzk::verif::VerifDialect *, scf::SCFDialect *) const final { + llzk::verif::attachInterfaces(*context); + } +}; + +} // namespace + //===------------------------------------------------------------------===// // VerifDialect //===------------------------------------------------------------------===// +void llzk::verif::attachInterfaces(MLIRContext &context) { + scf::WhileOp::attachInterface(context); + scf::ForOp::attachInterface(context); +} + +void llzk::verif::registerExtensions(mlir::DialectRegistry ®istry) { + registry.addExtension( + TypeID::get(), std::make_unique() + ); +} + auto llzk::verif::VerifDialect::initialize() -> void { // clang-format off addOperations< @@ -31,4 +118,6 @@ auto llzk::verif::VerifDialect::initialize() -> void { >(); // clang-format on addInterfaces>(); + declarePromisedInterface(); + declarePromisedInterface(); } diff --git a/lib/Dialect/Verif/IR/Ops.cpp b/lib/Dialect/Verif/IR/Ops.cpp index c78eb86aa7..3cf51454d6 100644 --- a/lib/Dialect/Verif/IR/Ops.cpp +++ b/lib/Dialect/Verif/IR/Ops.cpp @@ -33,11 +33,16 @@ #include #include #include +#include #include #include +#include +#include #include +#include + // TableGen'd implementation files #include "llzk/Dialect/Verif/IR/OpInterfaces.cpp.inc" @@ -699,6 +704,13 @@ FailureOr> ContractOp::getFuncTarget(SymbolTableCo ); } +FailureOr> +ContractOp::getTargetOp(SymbolTableCollection &tables) { + return lookupTopLevelSymbol( + tables, getTarget(), getParentOfType(getOperation()), /*reportMissing*/ false + ); +} + FailureOr ContractOp::getSelfValue() { if (failed(getStructTarget()) || getNumArguments() == 0) { return failure(); @@ -1081,4 +1093,161 @@ Operation *IncludeOp::resolveCallable() { return resolveCallableInTable(&tables); } +//===------------------------------------------------------------------===// +// InvariantOp +//===------------------------------------------------------------------===// + +void InvariantOp::build( + OpBuilder &odsBuilder, OperationState &odsState, StringRef loop_name, + ArrayRef loop_arg_types, ArrayRef loop_arg_locs +) { + odsState.getOrAddProperties().loop_name = + odsBuilder.getStringAttr(loop_name); + odsState.getOrAddProperties().loop_arg_types = + odsBuilder.getTypeArrayAttr(loop_arg_types); + auto region = std::make_unique(); + auto &block = region->emplaceBlock(); + block.addArguments(loop_arg_types, loop_arg_locs); + odsState.regions.push_back(std::move(region)); +} + +namespace { +static LogicalResult verifyArgTypes(InvariantTargetOpInterface target, InvariantOp *op) { + auto targetArgTypes = target.getArgumentTypes(); + auto declaredTypes = op->getLoopArgTypes().getValue(); + auto bodyArgTypes = op->getBody()->getArgumentTypes(); + + if (targetArgTypes.size() != declaredTypes.size()) { + return op->emitOpError() << "target has " << targetArgTypes.size() + << " arguments but invariant declared " << declaredTypes.size(); + } + if (bodyArgTypes.size() != declaredTypes.size()) { + return op->emitOpError() << "invariant body has " << targetArgTypes.size() + << " arguments but declared " << declaredTypes.size(); + } + + bool failed = false; + for (auto [n, types] : + llvm::enumerate(llvm::zip_equal(targetArgTypes, bodyArgTypes, declaredTypes))) { + auto [targetType, bodyArgType, declaredType] = types; + + if (targetType != mlir::cast(declaredType).getValue()) { + failed = true; + op->emitOpError() << "target argument #" << n << " expected type " << targetType + << " but invariant declared type " << declaredType; + } + if (bodyArgType != mlir::cast(declaredType).getValue()) { + failed = true; + op->emitOpError() << "invariant argument #" << n << " expected type " << targetType + << " but invariant declared type " << declaredType; + } + } + + return failure(failed); +} +} // namespace + +LogicalResult InvariantOp::verify() { + auto invariantTarget = getTarget(); + if (failed(invariantTarget)) { + return failure(); + } + + return verifyArgTypes(*invariantTarget, this); +} + +ParseResult InvariantOp::parse(OpAsmParser &parser, OperationState &result) { + if (failed(parser.parseKeyword("for"))) { + return failure(); + } + + // Parse the loop label as a symbol. + StringAttr loopNameAttr; + if (parser.parseSymbolName(loopNameAttr)) { + return failure(); + } + result.getOrAddProperties().loop_name = loopNameAttr; + + // Parse the function signature. + bool isVariadic = false; + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + + if (function_interface_impl::parseFunctionSignature( + parser, /*allowVariadic*/ false, entryArgs, isVariadic, resultTypes, resultAttrs + )) { + return failure(); + } + assert(isVariadic == false); + // There should be no return types or attributes. + if (!resultTypes.empty() || !resultAttrs.empty()) { + return failure(); + } + + SmallVector argTypes = llvm::map_to_vector(entryArgs, [](auto arg) { return arg.type; }); + result.getOrAddProperties().loop_arg_types = + parser.getBuilder().getTypeArrayAttr(argTypes); + + auto *body = result.addRegion(); + SMLoc loc = parser.getCurrentLocation(); + if (parser.parseRegion( + *body, entryArgs, + /*enableNameShadowing=*/false + )) { + return failure(); + } + + if (body->empty()) { + return parser.emitError(loc, "expected non-empty invariant body"); + } + + return success(); +} + +void InvariantOp::print(OpAsmPrinter &p) { + // Print the name of the invariants's target. + p << " for "; + p.printSymbolName(getLoopName()); + p << "("; + llvm::interleave(getBody()->getArguments(), [&p](auto arg) { + p.printRegionArgument(arg); + }, [&p]() { p << ", "; }); + p << ") "; + // Print the body. + Region &body = getRegion(); + p.printRegion( + body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true + ); +} + +ContractOp InvariantOp::getParentContract() { + return this->getOperation()->getParentOfType(); +} + +FailureOr InvariantOp::getTarget() { + auto target = getParentContract().getTargetOp(); + if (failed(target)) { + return failure(); + } + SmallVector matches; + for (auto invariantTarget : target->get().getLoops()) { + auto targetLabel = invariantTarget.getLabel(); + if (succeeded(targetLabel) && *targetLabel == getLoopName()) { + matches.push_back(invariantTarget); + } + } + + if (matches.size() == 0) { + return emitOpError() << "no invariant target with label \"" << getLoopName() + << "\" found in contract target " << target->get().getNameAttr(); + } + if (matches.size() > 1) { + return emitOpError() << "ambiguous label \"" << getLoopName() << "\" matched " << matches.size() + << " invariant targets in contract target " << target->get().getNameAttr(); + } + return matches[0]; +} + } // namespace llzk::verif diff --git a/test/Dialect/Verif/invariant_inners_fail.llzk b/test/Dialect/Verif/invariant_inners_fail.llzk new file mode 100644 index 0000000000..a206ccc3e4 --- /dev/null +++ b/test/Dialect/Verif/invariant_inners_fail.llzk @@ -0,0 +1,101 @@ +// RUN: llzk-opt -split-input-file -verify-diagnostics %s + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + // expected-error@+1 {{'verif.step' op must have an ancestor of type 'verif.invariant'}} + verif.step { + %2 = arith.constant true + verif.step.yield %2 + } + } +} + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + %0 = felt.const 1 + // expected-error@+1 {{'verif.increases' op must have an ancestor of type 'verif.invariant'}} + verif.increases %0 + } +} + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + %0 = felt.const 1 + // expected-error@+1 {{'verif.decreases' op must have an ancestor of type 'verif.invariant'}} + verif.decreases %0 + } +} + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + %0 = felt.const 1 + // expected-error@+1 {{'verif.old' op must have an ancestor of type 'verif.step'}} + %1 = verif.old %0 : !felt.type + } +} + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + %2 = arith.constant true + // expected-error@+1 {{'verif.step.yield' op expects parent op 'verif.step'}} + verif.step.yield %2 + } +} diff --git a/test/Dialect/Verif/invariant_inners_pass.llzk b/test/Dialect/Verif/invariant_inners_pass.llzk new file mode 100644 index 0000000000..d00cdf9fea --- /dev/null +++ b/test/Dialect/Verif/invariant_inners_pass.llzk @@ -0,0 +1,174 @@ +// RUN: llzk-opt -split-input-file %s 2>&1 | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb: index, %iv: index, %ub: index, %step: index) { + %0 = cast.tofelt %iv : index + verif.increases %0 + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[C10:[0-9a-zA-Z_\.]+]] = arith.constant 10 : index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[C0]] to %[[C10]] step %[[C1]] { +// CHECK-NEXT: } {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA( +// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = cast.tofelt %[[VAL_4]] : index +// CHECK-NEXT: verif.increases %[[VAL_7]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb: index, %iv: index, %ub: index, %step: index) { + %0 = arith.subi %ub, %iv : index + %1 = cast.tofelt %0 : index + verif.decreases %1 + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[C10:[0-9a-zA-Z_\.]+]] = arith.constant 10 : index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[C0]] to %[[C10]] step %[[C1]] { +// CHECK-NEXT: } {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA( +// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_5]], %[[VAL_4]] : index +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = cast.tofelt %[[VAL_7]] : index +// CHECK-NEXT: verif.decreases %[[VAL_8]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb: index, %iv: index, %ub: index, %step: index) { + verif.step { + %0 = verif.old %iv : index + %1 = arith.subi %iv, %0 : index + %2 = arith.cmpi eq, %1, %step : index + verif.step.yield %2 + } + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[C10:[0-9a-zA-Z_\.]+]] = arith.constant 10 : index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[C0]] to %[[C10]] step %[[C1]] { +// CHECK-NEXT: } {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA( +// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) { +// CHECK-NEXT: verif.step { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = verif.old %[[VAL_4]] : index +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = arith.subi %[[VAL_4]], %[[VAL_7]] : index +// CHECK-NEXT: %[[VAL_9:[0-9a-zA-Z_\.]+]] = arith.cmpi eq, %[[VAL_8]], %[[VAL_6]] : index +// CHECK-NEXT: verif.step.yield %[[VAL_9]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/test/Dialect/Verif/invariants_fail.llzk b/test/Dialect/Verif/invariants_fail.llzk new file mode 100644 index 0000000000..d8fc6dd88d --- /dev/null +++ b/test/Dialect/Verif/invariants_fail.llzk @@ -0,0 +1,105 @@ +// RUN: llzk-opt -split-input-file -verify-diagnostics %s + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + scf.while : () -> () { + %true = arith.constant true + scf.condition(%true) + } do { + scf.yield + } + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + // expected-error@+1 {{'verif.invariant' op no invariant target with label "loopA" found in contract target "Top"}} + verif.invariant for @loopA() { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + scf.while : () -> () { + %true = arith.constant true + scf.condition(%true) + } do { + scf.yield + } attributes {loop_label = "loopB"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + // expected-error@+1 {{'verif.invariant' op no invariant target with label "loopA" found in contract target "Top"}} + verif.invariant for @loopA() { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// ----- + +module attributes {llzk.lang} { + function.def @Bar() attributes {function.allow_witness} { + %c0 = felt.const 4 + %0 = scf.while (%1 = %c0) : (!felt.type) -> (!felt.type) { + %true = arith.constant true + scf.condition(%true) %1 : !felt.type + } do { + ^bb0(%2: !felt.type): + scf.yield %2 : !felt.type + } attributes {loop_label = "loopA"} + function.return + } + + verif.contract @Foo for @Bar () { + // expected-error@+1 {{'verif.invariant' op target has 1 arguments but invariant declared 2}} + verif.invariant for @loopA(%0: !felt.type, %1: !felt.type) { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// ----- + +module attributes {llzk.lang} { + function.def @Bar() attributes {function.allow_witness} { + %c0 = felt.const 4 + %0 = scf.while (%1 = %c0) : (!felt.type) -> (!felt.type) { + %true = arith.constant true + scf.condition(%true) %1 : !felt.type + } do { + ^bb0(%2: !felt.type): + scf.yield %2 : !felt.type + } attributes {loop_label = "loopA"} + function.return + } + + verif.contract @Foo for @Bar () { + // expected-error@+1 {{'verif.invariant' op target argument #0 expected type '!felt.type' but invariant declared type !array.type<1 x !felt.type>}} + verif.invariant for @loopA(%0: !array.type<1 x !felt.type>) { + %true = arith.constant true + verif.require_compute %true + } + } +} diff --git a/test/Dialect/Verif/invariants_pass.llzk b/test/Dialect/Verif/invariants_pass.llzk new file mode 100644 index 0000000000..9a193dd19d --- /dev/null +++ b/test/Dialect/Verif/invariants_pass.llzk @@ -0,0 +1,248 @@ +// RUN: llzk-opt -split-input-file %s 2>&1 | FileCheck --enable-var-scope %s + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + scf.while : () -> () { + %true = arith.constant true + scf.condition(%true) + } do { + scf.yield + } attributes {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA() { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: scf.while : () -> () { +// CHECK-NEXT: %true = arith.constant true +// CHECK-NEXT: scf.condition(%true) +// CHECK-NEXT: } do { +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } attributes {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA() { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: verif.require_compute %[[VAL_3]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb : index, %iv : index, %ub : index, %step : index) { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[C10:[0-9a-zA-Z_\.]+]] = arith.constant 10 : index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[C0]] to %[[C10]] step %[[C1]] { +// CHECK-NEXT: } {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA( +// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_6:[0-9a-zA-Z_\.]+]]: index) { +// CHECK-NEXT: %[[VAL_7:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: verif.require_compute %[[VAL_7]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + function.def @Bar() attributes {function.allow_witness} { + scf.while : () -> () { + %true = arith.constant true + scf.condition(%true) + } do { + scf.yield + } attributes {loop_label = "loopA"} + function.return + } + + verif.contract @Foo for @Bar () { + verif.invariant for @loopA() { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @Bar() attributes {function.allow_witness} { +// CHECK-NEXT: scf.while : () -> () { +// CHECK-NEXT: %true = arith.constant true +// CHECK-NEXT: scf.condition(%true) +// CHECK-NEXT: } do { +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } attributes {loop_label = "loopA"} +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Bar () { +// CHECK-NEXT: verif.invariant for @loopA() { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: verif.require_compute %[[VAL_3]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + function.def @Bar() attributes {function.allow_witness} { + %c0 = felt.const 4 + %0 = scf.while (%1 = %c0) : (!felt.type) -> (!felt.type) { + %true = arith.constant true + scf.condition(%true) %1 : !felt.type + } do { + ^bb0(%2: !felt.type): + scf.yield %2 : !felt.type + } attributes {loop_label = "loopA"} + function.return + } + + verif.contract @Foo for @Bar () { + verif.invariant for @loopA(%0: !felt.type) { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: function.def @Bar() attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = felt.const 4 +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = scf.while (%[[VAL_2:[0-9a-zA-Z_\.]+]] = %[[VAL_0]]) : (!felt.type) -> !felt.type { +// CHECK-NEXT: %[[VAL_3:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: scf.condition(%[[VAL_3]]) +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_4:[0-9a-zA-Z_\.]+]]: !felt.type): +// CHECK-NEXT: scf.yield %[[VAL_4]] : !felt.type +// CHECK-NEXT: } attributes {loop_label = "loopA"} +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Bar () { +// CHECK-NEXT: verif.invariant for @loopA(%[[VAL_5:[0-9a-zA-Z_\.]+]]: !felt.type) { +// CHECK-NEXT: %[[VAL_6:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: verif.require_compute %[[VAL_6]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +// ----- + +module attributes {llzk.lang} { + struct.def @Top { + function.def @compute() -> !struct.type<@Top> { + %self = struct.new : !struct.type<@Top> + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %0 = felt.const 5 + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%2 = %0) -> !felt.type { + scf.yield %2 : !felt.type + } {loop_label = "loopA"} + + function.return %self : !struct.type<@Top> + } + function.def @constrain(%self: !struct.type<@Top>) { + function.return + } + } + + verif.contract @Foo for @Top (%self: !struct.type<@Top>) { + verif.invariant for @loopA(%lb: index, %iv: index, %ub: index, %step: index, %extra: !felt.type) { + %true = arith.constant true + verif.require_compute %true + } + } +} + +// CHECK-LABEL: module attributes {llzk.lang} { +// CHECK-NEXT: struct.def @Top { +// CHECK-NEXT: function.def @compute() -> !struct.type<@Top> attributes {function.allow_witness} { +// CHECK-NEXT: %[[VAL_0:[0-9a-zA-Z_\.]+]] = struct.new : <@Top> +// CHECK-NEXT: %[[C0:[0-9a-zA-Z_\.]+]] = arith.constant 0 : index +// CHECK-NEXT: %[[C10:[0-9a-zA-Z_\.]+]] = arith.constant 10 : index +// CHECK-NEXT: %[[C1:[0-9a-zA-Z_\.]+]] = arith.constant 1 : index +// CHECK-NEXT: %[[VAL_1:[0-9a-zA-Z_\.]+]] = felt.const 5 +// CHECK-NEXT: %[[VAL_2:[0-9a-zA-Z_\.]+]] = scf.for %[[IV:[0-9a-zA-Z_\.]+]] = %[[C0]] to %[[C10]] step %[[C1]] +// CHECK-SAME: iter_args(%[[VAL_3:[0-9a-zA-Z_\.]+]] = %[[VAL_1]]) -> (!felt.type) { +// CHECK-NEXT: scf.yield %[[VAL_3]] : !felt.type +// CHECK-NEXT: } {loop_label = "loopA"} +// CHECK-NEXT: function.return %[[VAL_0]] : !struct.type<@Top> +// CHECK-NEXT: } +// CHECK-NEXT: function.def @constrain(%[[VAL_1:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) attributes {function.allow_constraint} { +// CHECK-NEXT: function.return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: verif.contract @Foo for @Top (%[[VAL_2:[0-9a-zA-Z_\.]+]]: !struct.type<@Top>) { +// CHECK-NEXT: verif.invariant for @loopA( +// CHECK-SAME: %[[VAL_3:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_4:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_5:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_6:[0-9a-zA-Z_\.]+]]: index, +// CHECK-SAME: %[[VAL_7:[0-9a-zA-Z_\.]+]]: !felt.type) { +// CHECK-NEXT: %[[VAL_8:[0-9a-zA-Z_\.]+]] = arith.constant true +// CHECK-NEXT: verif.require_compute %[[VAL_8]] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/unittests/CAPI/Dialect/Verif.cpp b/unittests/CAPI/Dialect/Verif.cpp index c9eb033b4e..e13ce26db5 100644 --- a/unittests/CAPI/Dialect/Verif.cpp +++ b/unittests/CAPI/Dialect/Verif.cpp @@ -11,6 +11,8 @@ #include "../CAPITestBase.h" +#include "llzk/Dialect/Bool/IR/Ops.h" +#include "llzk/Dialect/Cast/IR/Ops.h" #include "llzk/Dialect/Function/IR/Ops.h" #include "llzk/Dialect/Shared/Builders.h" #include "llzk/Dialect/Verif/IR/Ops.h" @@ -20,9 +22,16 @@ #include #include +#include +#include +#include #include +#include #include +#include + +#include // Include the auto-generated tests #include "llzk/Dialect/Verif/IR/Dialect.capi.test.cpp.inc" @@ -58,13 +67,14 @@ static MlirAttribute createEmptyFunctionTypeAttr(MlirContext ctx) { } static mlir::OwningOpRef createModuleWithTargetFunc( - const CAPITest &test, MlirOpBuilder builder, MlirLocation location, llvm::StringRef name + const CAPITest &test, MlirOpBuilder builder, MlirLocation location, llvm::StringRef name, + llvm::function_ref fnBody = nullptr ) { auto newModule = test.cppNewModuleAndSetInsertionPoint(builder, location); llzk::ModuleBuilder modBuilder(newModule.get()); modBuilder.insertFreeFunc( name, mlir::FunctionType::get(unwrap(test.context), mlir::TypeRange {}, mlir::TypeRange {}), - unwrap(location) + unwrap(location), fnBody ); unwrap(builder)->setInsertionPointToStart(newModule->getBody()); return newModule; @@ -392,3 +402,192 @@ std::unique_ptr RequireConstrainOpBuildFuncHe }; return std::make_unique(); } + +TEST_F(CAPITest, llzkVerifInvariantOpBuild) { + MlirOpBuilder builder = mlirOpBuilderCreate(context); + MlirLocation location = mlirLocationUnknownGet(context); + auto module = parseSourceString( + R"mlir( +module attributes {llzk.lang} { + function.def @target() attributes {function.allow_witness} { + scf.while : () -> () { + %true = arith.constant true + scf.condition(%true) + } do { + scf.yield + } attributes {loop_label = "loopA"} + function.return + } +} +)mlir", + mlir::ParserConfig(unwrap(context)) + ); + ASSERT_TRUE(module); + unwrap(builder)->setInsertionPointToEnd(module->getBody()); + + auto contract = createCppContract(builder, location, "ContractUnderTest", "target"); + unwrap(builder)->setInsertionPointToStart(&contract.getBody().front()); + + auto invariant = llzkVerif_InvariantOpBuild( + builder, location, mlirStringRefCreateFromCString("loopA"), 0, nullptr, nullptr + ); + EXPECT_TRUE(mlirOperationVerify(invariant)); + + mlirOpBuilderDestroy(builder); +} + +TEST_F(CAPITest, llzkVerifInvariantOpBuildWithArgs) { + MlirOpBuilder builder = mlirOpBuilderCreate(context); + MlirLocation location = mlirLocationUnknownGet(context); + auto module = parseSourceString( + R"mlir( +module attributes {llzk.lang} { + function.def @target() attributes {function.allow_witness} { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + function.return + } +} +)mlir", + mlir::ParserConfig(unwrap(context)) + ); + ASSERT_TRUE(module); + unwrap(builder)->setInsertionPointToEnd(module->getBody()); + size_t argCount = 4; + llvm::SmallVector argTypes(argCount, mlirIndexTypeGet(context)); + llvm::SmallVector argLocs(argCount, location); + + auto contract = createCppContract(builder, location, "ContractUnderTest", "target"); + unwrap(builder)->setInsertionPointToStart(&contract.getBody().front()); + + auto invariant = llzkVerif_InvariantOpBuild( + builder, location, mlirStringRefCreateFromCString("loopA"), static_cast(argCount), + argTypes.data(), argLocs.data() + ); + EXPECT_TRUE(mlirOperationVerify(invariant)); + + mlirOpBuilderDestroy(builder); +} + +namespace { +struct VerifInvariantInnerOpBuildBase { + mlir::OwningOpRef parentModule; + + mlir::ValueRange + prepareInsertionSite(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) { + this->parentModule = parseSourceString( + R"mlir( +module attributes {llzk.lang} { + function.def @target() attributes {function.allow_witness} { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + scf.for %iv = %c0 to %c10 step %c1 { + scf.yield + } {loop_label = "loopA"} + function.return + } +} +)mlir", + mlir::ParserConfig(unwrap(testClass.context)) + ); + unwrap(builder)->setInsertionPointToEnd(parentModule->getBody()); + + size_t argCount = 4; + llvm::SmallVector argTypes(argCount, mlirIndexTypeGet(testClass.context)); + llvm::SmallVector argLocs(argCount, location); + + auto contract = createCppContract(builder, location, "ContractUnderTest", "target"); + unwrap(builder)->setInsertionPointToStart(&contract.getBody().front()); + + auto invariant = llzkVerif_InvariantOpBuild( + builder, location, mlirStringRefCreateFromCString("loopA"), static_cast(argCount), + argTypes.data(), argLocs.data() + ); + auto *invariantBody = cast(unwrap(invariant)).getBody(); + unwrap(builder)->setInsertionPointToStart(invariantBody); + return invariantBody->getArguments(); + } +}; +} // namespace + +std::unique_ptr IncreasesOpBuildFuncHelper::get() { + struct Impl : public IncreasesOpBuildFuncHelper, VerifInvariantInnerOpBuildBase { + MlirOperation + callBuild(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) override { + auto args = prepareInsertionSite(testClass, builder, location); + auto step = unwrap(builder)->create(unwrap(location), args[3]); + return llzkVerif_IncreasesOpBuild(builder, location, wrap(step->getResult(0))); + } + }; + return std::make_unique(); +} + +std::unique_ptr DecreasesOpBuildFuncHelper::get() { + struct Impl : public DecreasesOpBuildFuncHelper, VerifInvariantInnerOpBuildBase { + MlirOperation + callBuild(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) override { + auto args = prepareInsertionSite(testClass, builder, location); + auto step = unwrap(builder)->create(unwrap(location), args[3]); + return llzkVerif_DecreasesOpBuild(builder, location, wrap(step->getResult(0))); + } + }; + return std::make_unique(); +} + +std::unique_ptr StepOpBuildFuncHelper::get() { + struct Impl : public StepOpBuildFuncHelper, VerifInvariantInnerOpBuildBase { + MlirOperation + callBuild(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) override { + prepareInsertionSite(testClass, builder, location); + auto op = llzkVerif_StepOpBuild(builder, location); + auto *region = unwrap(llzkVerif_StepOpGetRegion(op)); + unwrap(builder)->setInsertionPointToEnd(®ion->emplaceBlock()); + auto trueOp = unwrap(builder)->create( + unwrap(location), 1, unwrap(builder)->getI1Type() + ); + unwrap(builder)->create(unwrap(location), trueOp); + return op; + } + }; + return std::make_unique(); +} + +std::unique_ptr StepYieldOpBuildFuncHelper::get() { + struct Impl : public StepYieldOpBuildFuncHelper, VerifInvariantInnerOpBuildBase { + MlirOperation + callBuild(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) override { + prepareInsertionSite(testClass, builder, location); + auto stepOp = unwrap(builder)->create(unwrap(location)); + unwrap(builder)->setInsertionPointToEnd(&stepOp.getRegion().emplaceBlock()); + auto trueOp = unwrap(builder)->create( + unwrap(location), 1, unwrap(builder)->getI1Type() + ); + return llzkVerif_StepYieldOpBuild(builder, location, wrap(trueOp.getResult())); + } + }; + return std::make_unique(); +} + +std::unique_ptr OldOpBuildFuncHelper::get() { + struct Impl : public OldOpBuildFuncHelper, VerifInvariantInnerOpBuildBase { + MlirOperation + callBuild(const CAPITest &testClass, MlirOpBuilder builder, MlirLocation location) override { + auto args = prepareInsertionSite(testClass, builder, location); + auto stepOp = unwrap(builder)->create(unwrap(location)); + unwrap(builder)->setInsertionPointToEnd(&stepOp.getRegion().emplaceBlock()); + auto op = llzkVerif_OldOpBuild(builder, location, wrap(args[1])); + + auto trueOp = unwrap(builder)->create( + unwrap(location), 1, unwrap(builder)->getI1Type() + ); + unwrap(builder)->create(unwrap(location), trueOp); + return op; + } + }; + return std::make_unique(); +} From 957058fd7e16a05e90eb61186f211383924da4a2 Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" Date: Mon, 22 Jun 2026 05:10:01 +0800 Subject: [PATCH 11/12] Use named aux assignment visit states --- lib/Transforms/LLZKPolyLoweringPass.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/lib/Transforms/LLZKPolyLoweringPass.cpp b/lib/Transforms/LLZKPolyLoweringPass.cpp index 3ee780e378..abcaacdfbf 100644 --- a/lib/Transforms/LLZKPolyLoweringPass.cpp +++ b/lib/Transforms/LLZKPolyLoweringPass.cpp @@ -57,6 +57,12 @@ struct AuxAssignment { Value auxValue; }; +enum class AuxAssignmentVisitState : uint8_t { + Unvisited, + Visiting, + Done, +}; + class PassImpl : public llzk::impl::PolyLoweringPassBase { using Base = PolyLoweringPassBase; using Base::Base; @@ -116,26 +122,27 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { } LogicalResult visitAuxAssignment( - unsigned idx, ArrayRef> deps, SmallVectorImpl &visitState, - SmallVectorImpl &ordered, ArrayRef auxAssignments + unsigned idx, ArrayRef> deps, + SmallVectorImpl &visitState, SmallVectorImpl &ordered, + ArrayRef auxAssignments ) const { - if (visitState[idx] == 2) { + if (visitState[idx] == AuxAssignmentVisitState::Done) { return success(); } - if (visitState[idx] == 1) { + if (visitState[idx] == AuxAssignmentVisitState::Visiting) { return emitError(auxAssignments[idx].computedValue.getLoc()) << "poly lowering generated cyclic auxiliary dependency involving @" << auxAssignments[idx].auxMemberName; } - visitState[idx] = 1; + visitState[idx] = AuxAssignmentVisitState::Visiting; // Emit prerequisite aux writes before the aux writes that read them. for (unsigned dep : deps[idx]) { if (failed(visitAuxAssignment(dep, deps, visitState, ordered, auxAssignments))) { return failure(); } } - visitState[idx] = 2; + visitState[idx] = AuxAssignmentVisitState::Done; ordered.push_back(idx); return success(); } @@ -163,7 +170,9 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { ); } - SmallVector visitState(auxAssignments.size(), 0); + SmallVector visitState( + auxAssignments.size(), AuxAssignmentVisitState::Unvisited + ); for (unsigned idx = 0, e = auxAssignments.size(); idx < e; ++idx) { if (failed(visitAuxAssignment(idx, deps, visitState, ordered, auxAssignments))) { return failure(); From 12772acbd091a89109669e0e767494c824190077 Mon Sep 17 00:00:00 2001 From: "Cyne Jarvis J. Zarceno" Date: Tue, 23 Jun 2026 06:31:52 +0800 Subject: [PATCH 12/12] Address poly lowering review nits --- lib/Transforms/LLZKPolyLoweringPass.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/lib/Transforms/LLZKPolyLoweringPass.cpp b/lib/Transforms/LLZKPolyLoweringPass.cpp index 78ad19f045..c6f23e432a 100644 --- a/lib/Transforms/LLZKPolyLoweringPass.cpp +++ b/lib/Transforms/LLZKPolyLoweringPass.cpp @@ -76,6 +76,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { }); } + /// Records a dependency from the current aux assignment to a prerequisite. void addAuxDependency( unsigned dep, unsigned owner, DenseSet &seenDeps, SmallVectorImpl &deps ) const { @@ -87,6 +88,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { } } + /// Collects aux assignments that must be written before the given value can be rebuilt. void collectAuxDependencies( Value val, unsigned owner, const DenseMap &auxValueToIndex, const llvm::StringMap &auxNameToIndex, DenseSet &visitedValues, @@ -109,18 +111,16 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { } } - Operation *defOp = val.getDefiningOp(); - if (!defOp) { - return; - } - - for (Value operand : defOp->getOperands()) { - collectAuxDependencies( - operand, owner, auxValueToIndex, auxNameToIndex, visitedValues, seenDeps, deps - ); + if (Operation *defOp = val.getDefiningOp()) { + for (Value operand : defOp->getOperands()) { + collectAuxDependencies( + operand, owner, auxValueToIndex, auxNameToIndex, visitedValues, seenDeps, deps + ); + } } } + /// Visits aux assignments depth-first so dependencies are emitted before users. LogicalResult visitAuxAssignment( unsigned idx, ArrayRef> deps, SmallVectorImpl &visitState, SmallVectorImpl &ordered, @@ -147,6 +147,7 @@ class PassImpl : public llzk::impl::PolyLoweringPassBase { return success(); } + /// Produces a topological write order for generated aux assignments. LogicalResult orderAuxAssignments( ArrayRef auxAssignments, SmallVectorImpl &ordered ) const {